Mercurial > hg > smallbox
comparison DL/RLS-DLA/SMALL_rlsdla1.m @ 40:6416fc12f2b8
(none)
author | idamnjanovic |
---|---|
date | Mon, 14 Mar 2011 15:35:24 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
39:8f734534839a | 40:6416fc12f2b8 |
---|---|
function Dictionary = SMALL_rlsdla1(X, params)
%SMALL_RLSDLA1  Train a pair of dictionaries with recursive rank-one updates.
%
%   Dictionary = SMALL_rlsdla1(X, params) sorts the columns of X by l2 norm,
%   keeps the columns whose norm exceeds params.Edata and splits them into a
%   lower-norm half and a higher-norm half.  One dictionary is trained on
%   each half (D1 on the lower half, D2 on the upper half) and the result is
%   the horizontal concatenation [D1 D2].
%
%   Fields of params:
%     Edata    - (required) error target handed to omp2mex; also used as the
%                norm threshold for selecting training signals and to scale
%                the initial C matrices.
%     initdict - either a vector of whole-number column indices into X, or
%                an explicit matrix with size(X,1) rows.  When absent,
%                dictsize random columns of X with squared norm > 1e-6 are
%                used.
%     dictsize - number of atoms per dictionary; overrides the size implied
%                by initdict.
%     memusage - 'low' | 'normal' | 'high' (default 'normal'); only stored
%                in the global memusage.
%     muthresh - default 0.99; validated (must be non-negative) but not
%                otherwise used by this function.
%     iternum  - accepted for compatibility but not used: the number of
%                passes over the data is fixed at 3.
%
%   Side effects: assigns the globals CODE_SPARSITY, CODE_ERROR, codemode,
%   MEM_LOW, MEM_NORMAL, MEM_HIGH, memusage, ompfunc and ompparams.

global CODE_SPARSITY CODE_ERROR codemode
global MEM_LOW MEM_NORMAL MEM_HIGH memusage
global ompfunc ompparams exactsvd

CODE_SPARSITY = 1;
CODE_ERROR = 2;

MEM_LOW = 1;
MEM_NORMAL = 2;
MEM_HIGH = 3;

numPasses      = 3;       % passes over the training data
maxPerHalf     = 20000;   % at most this many signals are drawn from each half
lambdaStart    = 0.9998;  % forgetting factor used in the first pass
lambdaStep     = 0.0001;  % increment applied to lambda after every pass


% coding mode %

% Only the error-constrained mode is supported.  Set the globals explicitly
% so that helpers sharing them never see a stale or empty codemode.
if (~isfield(params,'Edata'))
    error('Data sparse-coding target not specified');
end
thresh   = params.Edata;
codemode = CODE_ERROR;
ompfunc  = @omp2;
if (~iscell(ompparams))
    ompparams = {};   % an uninitialised global is [], which cannot be expanded with {:}
end


% memory usage %

if (isfield(params,'memusage'))
    switch lower(params.memusage)
        case 'low'
            memusage = MEM_LOW;
        case 'normal'
            memusage = MEM_NORMAL;
        case 'high'
            memusage = MEM_HIGH;
        otherwise
            error('Invalid memory usage mode');
    end
else
    memusage = MEM_NORMAL;
end


% mutual incoherence limit %

if (isfield(params,'muthresh'))
    muthresh = params.muthresh;
else
    muthresh = 0.99;
end
if (muthresh < 0)
    error('invalid muthresh value, must be non-negative');
end


% determine dictionary size %

dictsize = [];
initIsIndexList = false;
if (isfield(params,'initdict'))
    initIsIndexList = any(size(params.initdict)==1) && all(iswhole(params.initdict(:)));
    if (initIsIndexList)
        dictsize = length(params.initdict);
    else
        dictsize = size(params.initdict,2);
    end
end
if (isfield(params,'dictsize'))    % this supersedes the size determined by initdict
    dictsize = params.dictsize;
end
if (isempty(dictsize))
    error('Dictionary size not specified');
end

if (size(X,2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end


% initialize the dictionaries (both start from the same atoms) %

if (isfield(params,'initdict'))
    if (initIsIndexList)
        D1 = X(:,params.initdict(1:dictsize));
    else
        if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize)
            error('Invalid initial dictionary');
        end
        D1 = params.initdict(:,1:dictsize);
    end
else
    data_ids = find(colnorms_squared(X) > 1e-6);   % ensure no zero data elements are chosen
    if (length(data_ids) < dictsize)
        error('Number of training signals is smaller than number of atoms to train');
    end
    perm = randperm(length(data_ids));
    D1 = X(:,data_ids(perm(1:dictsize)));
end
D2 = D1;


% select training signals: columns sorted by ascending norm, above thresh %

X_norm = sqrt(sum(X.^2, 1));
[X_norm_sort, p] = sort(X_norm);
p1 = p(X_norm_sort > thresh);

% lower-norm half feeds D1, higher-norm half feeds D2
half = floor(size(p1,2)/2);


% recursive updates %

C1 = (100000*thresh)*eye(dictsize);
C2 = (100000*thresh)*eye(dictsize);
lambda = lambdaStart;

if (half > 0)
    for j = 1:numPasses

        % draw the same (sorted) random offsets from both halves
        p2 = randperm(half);
        p2 = sort(p2(1:min(maxPerHalf, half)));
        data1 = X(:,p1(p2));
        data2 = X(:,p1(half+p2));

        % reset the first diagonal entry at the start of every pass
        C1(1,1) = 0;
        C2(1,1) = 0;

        for i = 1:size(data1,2)
            [D1, C1] = rlsdla1_update(D1, C1, data1(:,i), thresh, lambda);
            [D2, C2] = rlsdla1_update(D2, C2, data2(:,i), thresh, lambda);
        end

        lambda = lambda + lambdaStep;
    end
end

Dictionary = [D1 D2];

end


function [D, C] = rlsdla1_update(D, C, x, thresh, lambda)
% Apply one rank-one update of dictionary D and matrix C for the signal x.

w = omp2mex(D,x,[],[],[],thresh,0,-1,-1,0);
spind = find(w);

residual = x - D * w;

C = C * (1/lambda);
u = C(:,spind) * w(spind);

alfa = 1/(1 + w' * u);
D = D + (alfa * residual) * u';
C = C - (alfa * u) * u';

end
362 | |
363 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
364 % sparsecode % | |
365 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
366 | |
function Gamma = sparsecode(data,D,XtX,G,thresh)
% Sparse-code the columns of data over D using the solver stored in the
% global function handle ompfunc.
%
%   data   - signals to code
%   D      - dictionary
%   XtX    - accepted for interface compatibility; not used here
%   G      - passed through to ompfunc
%   thresh - passed through to ompfunc
%
% Any extra arguments stored in the global cell array ompparams are
% appended to the solver call.

global CODE_SPARSITY codemode
global MEM_HIGH memusage
global ompfunc ompparams

% A global that was never assigned is [], and expanding a non-cell with {:}
% raises an error; treat that case as "no extra arguments".
if (iscell(ompparams))
    extraArgs = ompparams;
else
    extraArgs = {};
end

if (memusage < MEM_HIGH)
    Gamma = ompfunc(D,data,G,thresh,extraArgs{:});

else  % memusage is high

    if (codemode == CODE_SPARSITY)
        Gamma = ompfunc(D'*data,G,thresh,extraArgs{:});

    else
        Gamma = ompfunc(D, data, G, thresh,extraArgs{:});
    end

end

end
388 | |
389 | |
390 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
391 % compute_err % | |
392 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
393 | |
394 | |
function err = compute_err(D,Gamma,data)
% Report the quality measure matching the current global coding mode:
% in sparsity mode, the root of the summed squared representation error
% divided by numel(data); otherwise the average number of non-zero
% coefficients per column of data.

global CODE_SPARSITY codemode

inSparsityMode = (codemode == CODE_SPARSITY);

if (inSparsityMode)
    squaredErrors = reperror2(data,D,Gamma);
    err = sqrt(sum(squaredErrors)/numel(data));
else
    numSignals = size(data,2);
    err = nnz(Gamma)/numSignals;
end

end
406 | |
407 | |
408 | |
409 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
410 % cleardict % | |
411 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
412 | |
413 | |
function [D,cleared_atoms] = cleardict(D,X,muthresh,unused_sigs,replaced_atoms)
% Replace dictionary atoms with normalised training signals.
%
%   D              - dictionary, one atom per column
%   X              - training signals, one per column
%   muthresh       - limit on the inner product between two atoms
%   unused_sigs    - column indices into X that may be used as replacements;
%                    they are consumed from the end of the list
%   replaced_atoms - per-atom vector; it serves both as the usage count and
%                    as the "already replaced" flag
%
% An atom j is replaced when replaced_atoms(j) is zero and either its
% squared inner product with another atom exceeds muthresh^2 or its usage
% count is below use_thresh.  cleared_atoms is the number of atoms replaced.

use_thresh = 4;  % at least this number of samples must use the atom to be kept

dictsize = size(D,2);

cleared_atoms = 0;
usecount = replaced_atoms;

for j = 1:dictsize

    % nothing left to substitute with: indexing unused_sigs(end) would fail
    if (isempty(unused_sigs))
        break;
    end

    % compute G(:,j), ignoring the atom's product with itself
    Gj = D'*D(:,j);
    Gj(j) = 0;

    % replace atom
    if ( (max(Gj.^2)>muthresh^2 || usecount(j)<use_thresh) && ~replaced_atoms(j) )
        candidate = X(:,unused_sigs(end));
        D(:,j) = candidate / norm(candidate);
        unused_sigs = unused_sigs(1:end-1);
        cleared_atoms = cleared_atoms+1;
    end
end

end
446 | |
447 | |
448 | |
449 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
450 % misc functions % | |
451 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
452 | |
453 | |
function err2 = reperror2(X,D,Gamma)
% Squared representation error of every column: err2(k) is the sum of the
% squared entries of X(:,k) - D*Gamma(:,k).  Returns a 1-by-size(X,2) row.

% compute in blocks to conserve memory
err2 = zeros(1,size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
    blockids = i : min(i+blocksize-1,size(X,2));
    % sum along dimension 1 explicitly: without it a single-row block is
    % treated as a vector and collapsed to one scalar
    err2(blockids) = sum((X(:,blockids) - D*Gamma(:,blockids)).^2, 1);
end

end
465 | |
466 | |
function Y = colnorms_squared(X)
% Squared l2 norm of every column of X, returned as a 1-by-size(X,2) row.

% compute in blocks to conserve memory
Y = zeros(1,size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
    blockids = i : min(i+blocksize-1,size(X,2));
    % sum along dimension 1 explicitly: without it a single-row block is
    % treated as a vector and collapsed to one scalar
    Y(blockids) = sum(X(:,blockids).^2, 1);
end

end
478 | |
479 |