Mercurial > hg > smallbox
comparison DL/RLS-DLA/SMALL_rlsdlaFirstClustTry.m @ 40:6416fc12f2b8
(none)
author | idamnjanovic |
---|---|
date | Mon, 14 Mar 2011 15:35:24 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
39:8f734534839a | 40:6416fc12f2b8 |
---|---|
function Dictionary = SMALL_rlsdla(X, params)
%SMALL_RLSDLA  Clustered RLS-DLA dictionary learning (experimental variant).
%
%   Dictionary = SMALL_rlsdla(X, params) learns a dictionary from the
%   training signals stored in the columns of X using recursive least
%   squares dictionary learning (RLS-DLA).
%
%   How it works:
%     - Signals whose l2 norm exceeds params.Edata are split into
%       params.numclust groups with kmeans.
%     - One sub-dictionary per group is trained, all starting from the same
%       initial dictionary.
%     - The result is the horizontal concatenation of the sub-dictionaries,
%       passed through normcols (external; presumably unit-l2 column
%       normalisation - confirm).
%
%   NOTE(review): this function's name differs from its file name
%   (SMALL_rlsdlaFirstClustTry.m); MATLAB dispatches on the file name.
%
%   Required fields of params:
%     Edata      - error target passed to omp2mex. Also the norm threshold
%                  below which training signals are ignored.
%     dictsize and/or initdict - number of atoms per sub-dictionary.
%
%   Optional fields of params:
%     initdict   - vector of column indices into X, or an explicit initial
%                  dictionary with size(X,1) rows. Without it, the initial
%                  atoms are randomly chosen non-zero columns of X.
%     dictsize   - atoms per sub-dictionary. Overrides the size implied by
%                  initdict.
%     memusage   - 'low' | 'normal' (default) | 'high'. Stored in the global
%                  memusage for the sparsecode helper.
%     muthresh   - mutual incoherence limit (default 0.99). Must be
%                  non-negative. Checked but otherwise unused by this variant.
%     numclust   - number of kmeans clusters / sub-dictionaries (default 4).
%     maxsamples - maximum number of training signals. Larger training sets
%                  are randomly subsampled (default 60000).
%     lambda     - RLS forgetting factor (default 0.99986).
%
%   Output:
%     Dictionary - size(X,1)-by-(numclust*dictsize) matrix. Columns
%                  (k-1)*dictsize+1 : k*dictsize form the sub-dictionary
%                  of cluster k.
%
%   Depends on kmeans (Statistics Toolbox) and on the external functions
%   iswhole, omp2mex and normcols.

global CODE_SPARSITY CODE_ERROR codemode
global MEM_LOW MEM_NORMAL MEM_HIGH memusage
global ompfunc ompparams exactsvd

CODE_SPARSITY = 1;
CODE_ERROR = 2;

MEM_LOW = 1;
MEM_NORMAL = 2;
MEM_HIGH = 3;

% coding mode %
% This variant always codes to an error target (params.Edata). The global
% is set explicitly: the previous version never assigned codemode, so the
% choice of ompfunc depended on whatever an earlier caller had left there.
codemode = CODE_ERROR;
thresh = params.Edata;
ompfunc = @omp2;

% memory usage %
if isfield(params,'memusage')
    switch lower(params.memusage)
        case 'low'
            memusage = MEM_LOW;
        case 'normal'
            memusage = MEM_NORMAL;
        case 'high'
            memusage = MEM_HIGH;
        otherwise
            error('Invalid memory usage mode');
    end
else
    memusage = MEM_NORMAL;
end

% mutual incoherence limit (validated only; atom clearing is disabled) %
if isfield(params,'muthresh')
    muthresh = params.muthresh;
else
    muthresh = 0.99;
end
if muthresh < 0
    error('invalid muthresh value, must be non-negative');
end

% clustering / RLS parameters; defaults equal the former hard-coded values %
numclust = 4;
if isfield(params,'numclust')
    numclust = params.numclust;
end
maxsamples = 60000;
if isfield(params,'maxsamples')
    maxsamples = params.maxsamples;
end
lambda = 0.99986;
if isfield(params,'lambda')
    lambda = params.lambda;
end

% determine dictionary size (atoms per sub-dictionary) %
initIsIndexList = false;
dictsize = [];
if isfield(params,'initdict')
    % A vector of whole numbers is read as column indices into X.
    initIsIndexList = any(size(params.initdict)==1) && all(iswhole(params.initdict(:)));
    if initIsIndexList
        dictsize = length(params.initdict);
    else
        dictsize = size(params.initdict,2);
    end
end
if isfield(params,'dictsize')    % this supersedes the size determined by initdict
    dictsize = params.dictsize;
end
if isempty(dictsize)
    error('Dictionary size not specified: set params.dictsize or params.initdict');
end

if size(X,2) < dictsize
    error('Number of training signals is smaller than number of atoms to train');
end

% initialize the dictionary %
if isfield(params,'initdict')
    if initIsIndexList
        D = X(:,params.initdict(1:dictsize));
    else
        if size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize
            error('Invalid initial dictionary');
        end
        D = params.initdict(:,1:dictsize);
    end
else
    data_ids = find(colnorms_squared(X) > 1e-6);   % ensure no zero data elements are chosen
    if numel(data_ids) < dictsize
        error('Not enough non-zero training signals to initialize the dictionary');
    end
    perm = randperm(length(data_ids));
    D = X(:,data_ids(perm(1:dictsize)));
end

% select training signals %
% Only signals with norm above the error target are used. p1 lists them in
% ascending order of norm, which is also the order of the RLS updates below.
Xnorm = sqrt(sum(X.^2, 1));
[Xnorm_sorted, order] = sort(Xnorm);
p1 = order(Xnorm_sorted > thresh);

% one copy of the initial dictionary per cluster
D = repmat(D, 1, numclust);

if isempty(p1)
    % Nothing to train on: return the replicated initial dictionary
    % (kmeans cannot be run on an empty set).
    Dictionary = normcols(D);
    return;
end

% cluster the training signals; signal p1(r) receives label idx(r) %
idx = kmeans(X(:,p1)', numclust, 'Start', 'cluster', 'MaxIter', 200);

% optional random subsampling, keeping the ascending-norm order %
nsig = numel(p1);
if nsig > maxsamples
    p2 = randperm(nsig);
    p2 = sort(p2(1:maxsamples));
    data = X(:,p1(p2));
    labels = idx(p2);   % keep labels aligned with the subsampled signals
else
    data = X(:,p1);
    labels = idx;
end

% RLS-DLA pass %
% Only the diagonal blocks of the former (numclust*dictsize)^2 matrix C
% were ever read or updated (the off-diagonal blocks stay zero), so C is
% stored as one dictsize-by-dictsize block per cluster. All blocks are
% scaled by 1/lambda at every sample, exactly as before.
C = repmat((100000*thresh)*eye(dictsize), [1 1 numclust]);
invlambda = 1/lambda;

for i = 1:size(data,2)
    k = labels(i);
    cols = (k-1)*dictsize + (1:dictsize);
    Dk = D(:,cols);
    x = data(:,i);

    % sparse-code x over its cluster's sub-dictionary to error target thresh
    w = omp2mex(Dk, x, [], [], [], thresh, 0, -1, -1, 0);
    spind = find(w);
    residual = x - Dk*w;

    C = C * invlambda;              % forgetting factor
    Ck = C(:,:,k);
    u = Ck(:,spind) * w(spind);     % equals Ck*w, using only non-zeros

    alfa = 1/(1 + w'*u);
    D(:,cols) = Dk + (alfa*residual)*u';
    C(:,:,k) = Ck - (alfa*u)*u';
end

Dictionary = normcols(D);
end
377 | |
378 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
379 % sparsecode % | |
380 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
381 | |
function Gamma = sparsecode(data,D,XtX,G,thresh)
%SPARSECODE  Sparse-code DATA over dictionary D via the global ompfunc.
%   Gamma = sparsecode(data, D, XtX, G, thresh) calls the function handle
%   held in the global ompfunc, appending the contents of the global cell
%   array ompparams. The precomputed product D'*data replaces (D, data)
%   only when memusage is not below MEM_HIGH and codemode equals
%   CODE_SPARSITY. XtX is accepted for signature compatibility but unused.
%   NOTE(review): the only calls to this helper in SMALL_rlsdla are
%   commented out - confirm whether it is still needed.

global CODE_SPARSITY codemode
global MEM_HIGH memusage
global ompfunc ompparams

if memusage < MEM_HIGH
    % low/normal memory: pass the dictionary and the raw signals
    Gamma = ompfunc(D, data, G, thresh, ompparams{:});
elseif codemode == CODE_SPARSITY
    % high memory with a sparsity target: pass the correlations D'*data
    Gamma = ompfunc(D'*data, G, thresh, ompparams{:});
else
    % high memory with an error target: same call as the low-memory path
    Gamma = ompfunc(D, data, G, thresh, ompparams{:});
end

end
403 | |
404 | |
405 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
406 % compute_err % | |
407 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
408 | |
409 | |
function err = compute_err(D,Gamma,data)
%COMPUTE_ERR  Figure of merit for the representation D*Gamma of DATA.
%   When the global codemode equals CODE_SPARSITY, returns the RMSE of
%   D*Gamma against DATA over all entries. Otherwise returns the mean
%   number of non-zero coefficients per signal (column of DATA).

global CODE_SPARSITY codemode

if codemode == CODE_SPARSITY
    sqErrPerSignal = reperror2(data, D, Gamma);
    err = sqrt(sum(sqErrPerSignal) / numel(data));
    return;
end

numSignals = size(data, 2);
err = nnz(Gamma) / numSignals;

end
421 | |
422 | |
423 | |
424 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
425 % cleardict % | |
426 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
427 | |
428 | |
function [D,cleared_atoms] = cleardict(D,X,muthresh,unused_sigs,replaced_atoms)
%CLEARDICT  Replace unused or near-duplicate dictionary atoms.
%   [D, cleared_atoms] = cleardict(D, X, muthresh, unused_sigs, replaced_atoms)
%   visits the atoms (columns) of D in order. Atom j is replaced when both
%   of the following hold:
%     - replaced_atoms(j) is zero, and
%     - either its largest squared inner product with another atom exceeds
%       muthresh^2, or its use count is below use_thresh.
%   Replacements are the columns of X listed in unused_sigs, taken from
%   the last entry backwards and scaled to unit l2 norm. cleared_atoms
%   counts the replacements. Stops early once unused_sigs is exhausted
%   (this previously raised an indexing error).
%
%   NOTE(review): usecount is an alias of replaced_atoms (the Gamma-based
%   count is commented out). So an atom is replaced exactly when
%   replaced_atoms(j) is zero, and the coherence test has no effect.
%   Confirm this is intended.

use_thresh = 4;   % at least this number of samples must use the atom to be kept

dictsize = size(D,2);
cleared_atoms = 0;
usecount = replaced_atoms;   % alternative: sum(abs(Gamma)>1e-7, 2)

for j = 1:dictsize
    if isempty(unused_sigs)
        break;   % no replacement signals left
    end

    % column j of the Gram matrix, excluding the atom's self inner product
    Gj = D'*D(:,j);
    Gj(j) = 0;

    if (max(Gj.^2) > muthresh^2 || usecount(j) < use_thresh) && ~replaced_atoms(j)
        newatom = X(:,unused_sigs(end));
        D(:,j) = newatom / norm(newatom);
        unused_sigs = unused_sigs(1:end-1);
        cleared_atoms = cleared_atoms + 1;
    end
end

end
461 | |
462 | |
463 | |
464 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
465 % misc functions % | |
466 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | |
467 | |
468 | |
function err2 = reperror2(X,D,Gamma,blocksize)
%REPERROR2  Squared l2 reconstruction error of each column of X.
%   err2 = reperror2(X, D, Gamma) returns a 1-by-size(X,2) row vector with
%   err2(i) = sum((X(:,i) - D*Gamma(:,i)).^2).
%   err2 = reperror2(X, D, Gamma, blocksize) processes BLOCKSIZE columns at
%   a time (default 2000) to bound the size of the intermediate D*Gamma.
%
%   The sum is taken explicitly along dimension 1. Without the dim argument,
%   a single-row X would be summed across columns into one value that was
%   then written to every column of the block.

if nargin < 4 || isempty(blocksize)
    blocksize = 2000;
end

err2 = zeros(1,size(X,2));
for i = 1:blocksize:size(X,2)
    blockids = i : min(i+blocksize-1,size(X,2));
    err2(blockids) = sum((X(:,blockids) - D*Gamma(:,blockids)).^2, 1);
end

end
480 | |
481 | |
function Y = colnorms_squared(X,blocksize)
%COLNORMS_SQUARED  Squared l2 norm of every column of X.
%   Y = colnorms_squared(X) returns a 1-by-size(X,2) row vector with
%   Y(i) = sum(X(:,i).^2).
%   Y = colnorms_squared(X, blocksize) processes BLOCKSIZE columns at a time
%   (default 2000) to limit the size of the temporary X.^2.
%
%   The sum is taken explicitly along dimension 1. Without the dim argument,
%   a single-row X would yield the total over the block, written to every
%   entry.

if nargin < 2 || isempty(blocksize)
    blocksize = 2000;
end

Y = zeros(1,size(X,2));
for i = 1:blocksize:size(X,2)
    blockids = i : min(i+blocksize-1,size(X,2));
    Y(blockids) = sum(X(:,blockids).^2, 1);
end

end
493 | |
494 |