Mercurial > hg > smallbox
comparison DL/RLS-DLA/SMALL_rlsdla 05032011.m @ 40:6416fc12f2b8
(none)
author | idamnjanovic |
---|---|
date | Mon, 14 Mar 2011 15:35:24 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
39:8f734534839a | 40:6416fc12f2b8 |
---|---|
function Dictionary = SMALL_rlsdla(X, params)
%SMALL_RLSDLA Dictionary learning with Recursive Least Squares DLA (RLS-DLA).
%
%   Dictionary = SMALL_rlsdla(X, params) makes one pass over the columns of
%   X (or over a random subset of at most maxsignals of them). For every
%   signal it computes a sparse code w with error-constrained OMP
%   (omp2mex, error target params.Edata) against the current dictionary D.
%   It then applies a rank-one recursive-least-squares update to D and to
%   the auxiliary matrix C, using the forgetting factor lambda.
%
%   Inputs
%     X       - training signals, one per column.
%     params  - struct with the fields
%       Edata            (required) error target passed to omp2mex.
%       initdict         (optional) initial dictionary (size(X,1) rows), or
%                        a vector of column indices of X to use as atoms.
%       dictsize         (optional) number of atoms; supersedes the size
%                        implied by initdict. At least one of initdict or
%                        dictsize must be given.
%       memusage         (optional) 'low' | 'normal' | 'high', default
%                        'normal'. Only stored in a global for the helper
%                        functions of this file.
%       muthresh         (optional) mutual-coherence limit, default 0.99,
%                        must be non-negative. It is validated only; the
%                        atom-clearing step that would use it is disabled.
%       forgettingFactor (optional) RLS forgetting factor lambda in (0,1],
%                        default 0.9997.
%       maxsignals       (optional) maximum number of training signals.
%                        When X has more columns, a random subset of this
%                        size (kept in original column order) is used.
%                        Default 40000.
%     A params.iternum field, if present, is ignored.
%
%   Output
%     Dictionary - learned dictionary, size(X,1) x dictsize.
%
%   Atoms are not normalized, neither at initialization nor after updates.

global CODE_SPARSITY CODE_ERROR codemode
global MEM_LOW MEM_NORMAL MEM_HIGH memusage
global ompfunc ompparams exactsvd

CODE_SPARSITY = 1;
CODE_ERROR = 2;

MEM_LOW = 1;
MEM_NORMAL = 2;
MEM_HIGH = 3;


% coding mode %
% The update loop always uses error-constrained OMP (omp2mex) with target
% params.Edata, so the coding mode is fixed to CODE_ERROR. The global must
% be set here because the helpers sparsecode/compute_err read it.

if (~isfield(params,'Edata'))
    error('Data sparse-coding target not specified (params.Edata is required)');
end
codemode = CODE_ERROR;
thresh = params.Edata;

% Empty OMP option list, so that sparsecode can expand ompparams{:}.
ompparams = {};


% memory usage %

if (isfield(params,'memusage'))
    switch lower(params.memusage)
        case 'low'
            memusage = MEM_LOW;
        case 'normal'
            memusage = MEM_NORMAL;
        case 'high'
            memusage = MEM_HIGH;
        otherwise
            error('Invalid memory usage mode');
    end
else
    memusage = MEM_NORMAL;
end


% omp function (error-constrained, matching codemode) %

ompfunc = @omp2;


% mutual incoherence limit %

if (isfield(params,'muthresh'))
    muthresh = params.muthresh;
else
    muthresh = 0.99;
end
if (muthresh < 0)
    error('invalid muthresh value, must be non-negative');
end


% forgetting factor %

if (isfield(params,'forgettingFactor'))
    lambda = params.forgettingFactor;
else
    lambda = 0.9997;
end
if (~isscalar(lambda) || lambda <= 0 || lambda > 1)
    error('invalid forgettingFactor value, must be in (0,1]');
end


% maximum number of training signals used %

if (isfield(params,'maxsignals'))
    maxsignals = params.maxsignals;
else
    maxsignals = 40000;
end
if (~isscalar(maxsignals) || maxsignals < 1)
    error('invalid maxsignals value, must be a positive integer');
end


% determine dictionary size %

dictsize = [];
if (isfield(params,'initdict'))
    if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:))))
        dictsize = length(params.initdict);
    else
        dictsize = size(params.initdict,2);
    end
end
if (isfield(params,'dictsize'))    % this supersedes the size determined by initdict
    dictsize = params.dictsize;
end
if (isempty(dictsize))
    error('Dictionary size not specified (set params.dictsize or params.initdict)');
end

if (size(X,2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end


% initialize the dictionary %

if (isfield(params,'initdict'))
    if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:))))
        D = X(:,params.initdict(1:dictsize));
    else
        if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize)
            error('Invalid initial dictionary');
        end
        D = params.initdict(:,1:dictsize);
    end
else
    data_ids = find(colnorms_squared(X) > 1e-6);   % ensure no zero data elements are chosen
    if (length(data_ids) < dictsize)
        error('Not enough non-zero training signals to initialize the dictionary');
    end
    perm = randperm(length(data_ids));
    D = X(:,data_ids(perm(1:dictsize)));
end


% select the training signals %

nsig = size(X,2);
if (nsig > maxsignals)
    % random subset, kept in the original column order
    p2 = randperm(nsig);
    data = X(:,sort(p2(1:maxsignals)));
else
    data = X;
end


% RLS-DLA pass %
% The rank-one update of C is the Sherman-Morrison formula. After each
% step C equals the inverse of (lambda * previous matrix + w*w'), so C
% tracks the inverse of the lambda-weighted sum of the code outer products.
% NOTE(review): the initial value C = 100000*thresh*I looks like an
% empirically chosen large seed (weak initial regularization) -- confirm.

C = (100000*thresh)*eye(dictsize);

for i = 1:size(data,2)
    x = data(:,i);

    % sparse code of x over the current dictionary
    w = omp2mex(D,x,[],[],[],thresh,0,-1,-1,0);
    spind = find(w);
    residual = x - D*w;

    % forgetting: scale the inverse correlation by 1/lambda
    C = C * (1/lambda);

    % u = C*w, using only the nonzero entries of w
    u = C(:,spind) * w(spind);
    alfa = 1/(1 + w'*u);

    D = D + (alfa*residual)*u';
    C = C - (alfa*u)*u';
end

Dictionary = D;
end
371 | |
372 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
373 % sparsecode % | |
374 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
375 | |
function Gamma = sparsecode(data,D,XtX,G,thresh)
%SPARSECODE Sparse-code DATA over dictionary D with the global ompfunc.
%
%   Gamma = sparsecode(data,D,XtX,G,thresh) dispatches on the globals
%   memusage and codemode:
%     - memusage < MEM_HIGH, or codemode ~= CODE_SPARSITY:
%           Gamma = ompfunc(D, data, G, thresh, ompparams{:})
%     - otherwise (high memory, sparsity mode):
%           Gamma = ompfunc(D'*data, G, thresh, ompparams{:})
%   XtX is accepted for interface compatibility but is not used.

global CODE_SPARSITY codemode
global MEM_HIGH memusage
global ompfunc ompparams

% Decide whether ompfunc receives the correlations D'*data instead of D.
% codemode is only inspected in the high-memory case.
if (memusage < MEM_HIGH)
    use_dtx = false;
else
    use_dtx = (codemode == CODE_SPARSITY);
end

if (use_dtx)
    Gamma = ompfunc(D'*data, G, thresh, ompparams{:});
else
    Gamma = ompfunc(D, data, G, thresh, ompparams{:});
end

end
397 | |
398 | |
399 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
400 % compute_err % | |
401 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
402 | |
403 | |
function err = compute_err(D,Gamma,data)
%COMPUTE_ERR Error measure matching the current coding mode.
%
%   err = compute_err(D,Gamma,data). When the global codemode equals
%   CODE_SPARSITY, err is the RMSE of the representation D*Gamma over all
%   entries of data. Otherwise err is the mean number of nonzero
%   coefficients in Gamma per column of data.

global CODE_SPARSITY codemode

if (codemode == CODE_SPARSITY)
    sqerr = reperror2(data,D,Gamma);        % squared error per signal
    err = sqrt(sum(sqerr)/numel(data));
else
    natoms = nnz(Gamma);
    err = natoms/size(data,2);
end

end
415 | |
416 | |
417 | |
418 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
419 % cleardict % | |
420 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
421 | |
422 | |
function [D,cleared_atoms] = cleardict(D,X,muthresh,unused_sigs,replaced_atoms)
%CLEARDICT Replace flagged atoms of D with normalized training signals.
%
%   [D,cleared_atoms] = cleardict(D,X,muthresh,unused_sigs,replaced_atoms)
%   scans the atoms of D in order. Atom j is replaced when
%     (max squared inner product with another atom > muthresh^2
%        OR usecount(j) < use_thresh)  AND  replaced_atoms(j) is zero.
%   The replacement is X(:,unused_sigs(end)) divided by its norm. That index
%   is then removed from unused_sigs, so candidates are consumed from the
%   end of the list. Scanning stops once unused_sigs is exhausted.
%
%   Inputs
%     D              - dictionary, one atom per column.
%     X              - training signals, one per column.
%     muthresh       - mutual-coherence threshold.
%     unused_sigs    - column indices of X available as replacement atoms.
%     replaced_atoms - per-atom vector, also used as the usage count.
%   Outputs
%     D              - dictionary with the cleared atoms replaced.
%     cleared_atoms  - number of atoms replaced.
%
%   NOTE(review): usecount is taken from replaced_atoms. For an atom with
%   replaced_atoms(j)==0 the usecount test is therefore always true, and the
%   atom is replaced regardless of coherence. Atoms with nonzero
%   replaced_atoms are never replaced. This mirrors the original behaviour;
%   confirm whether a separate usage count was intended.

use_thresh = 4;   % at least this number of samples must use the atom to be kept

dictsize = size(D,2);
cleared_atoms = 0;
usecount = replaced_atoms;

for j = 1:dictsize

    if (isempty(unused_sigs))
        % No replacement candidates left. Stop here instead of indexing an
        % empty vector, which previously raised an index error.
        break;
    end

    % compute G(:,j), ignoring the atom's correlation with itself
    Gj = D'*D(:,j);
    Gj(j) = 0;

    % replace atom
    if ( (max(Gj.^2)>muthresh^2 || usecount(j)<use_thresh) && ~replaced_atoms(j) )
        newatom = X(:,unused_sigs(end));
        D(:,j) = newatom / norm(newatom);
        unused_sigs(end) = [];
        cleared_atoms = cleared_atoms+1;
    end
end

end
455 | |
456 | |
457 | |
458 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
459 % misc functions % | |
460 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
461 | |
462 | |
function err2 = reperror2(X,D,Gamma)
%REPERROR2 Squared representation error of each signal.
%
%   err2 = reperror2(X,D,Gamma) returns a 1 x size(X,2) row vector whose
%   i-th entry is sum((X(:,i) - D*Gamma(:,i)).^2). Columns are processed in
%   blocks of 2000, so D*Gamma is never formed for all columns at once.

% compute in blocks to conserve memory
err2 = zeros(1,size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
    blockids = i : min(i+blocksize-1,size(X,2));
    % Sum explicitly over dimension 1. For single-row signals, sum() would
    % otherwise add across columns and return one scalar for the block.
    err2(blockids) = sum((X(:,blockids) - D*Gamma(:,blockids)).^2, 1);
end

end
474 | |
475 | |
function Y = colnorms_squared(X)
%COLNORMS_SQUARED Squared Euclidean norm of every column of X.
%
%   Y = colnorms_squared(X) returns a 1 x size(X,2) row vector with
%   Y(i) = sum(X(:,i).^2). Columns are processed in blocks of 2000 to limit
%   the size of the temporary X.^2.

% compute in blocks to conserve memory
Y = zeros(1,size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
    blockids = i : min(i+blocksize-1,size(X,2));
    % Sum explicitly over dimension 1. For a single-row X, sum() would
    % otherwise add across columns and return one scalar for the block.
    Y(blockids) = sum(X(:,blockids).^2, 1);
end

end
487 | |
488 |