comparison DL/RLS-DLA/SMALL_rlsdla1.m @ 40:6416fc12f2b8

(none)
author idamnjanovic
date Mon, 14 Mar 2011 15:35:24 +0000
parents
children
comparison
equal deleted inserted replaced
39:8f734534839a 40:6416fc12f2b8
function Dictionary = SMALL_rlsdla1(X, params)
%SMALL_rlsdla1 Train two dictionaries with an RLS-DLA style recursive
%least-squares update, using error-constrained sparse coding (omp2mex).
%
%   Dictionary = SMALL_rlsdla1(X, params)
%
%   Training signals (columns of X) whose l2 norm exceeds params.Edata are
%   sorted by increasing norm and split into two halves. D1 is updated with
%   signals from the lower-norm half and D2 with the matching signals of the
%   higher-norm half. There are 3 passes, each drawing at most 20000 signals
%   per half. The returned dictionary is [D1 D2].
%
%   Inputs:
%     X       - training signals, one per column.
%     params  - struct with fields:
%       Edata     (required) error target passed to omp2mex; signals with
%                 norm not above it are not used for training.
%       initdict  (optional) initial dictionary (matrix with size(X,1) rows)
%                 or a vector of column indices into X. Both D1 and D2 start
%                 from it.
%       dictsize  (optional) atoms per dictionary; overrides the size implied
%                 by initdict. Required when initdict is absent.
%       memusage  (optional) 'low' | 'normal' | 'high' (default 'normal');
%                 stored in a global for the sparsecode helper.
%       muthresh  (optional) mutual-incoherence limit (default 0.99); only
%                 validated here.
%
%   Output:
%     Dictionary - [D1 D2], size(X,1)-by-(2*dictsize).
%
%   Relies on omp2mex and iswhole, which are assumed to come from the
%   OMP/K-SVD toolbox (confirm they are on the path).

global CODE_SPARSITY CODE_ERROR codemode
global MEM_LOW MEM_NORMAL MEM_HIGH memusage
global ompfunc ompparams exactsvd

CODE_SPARSITY = 1;
CODE_ERROR = 2;

MEM_LOW = 1;
MEM_NORMAL = 2;
MEM_HIGH = 3;


% coding target: this routine always codes with an error target %

if (~isfield(params,'Edata'))
  error('Data sparse-coding target not specified');
end
thresh = params.Edata;

% Set the mode explicitly rather than relying on a possibly stale or
% unset global; omp2mex below is the error-constrained coder.
codemode = CODE_ERROR;
ompfunc = @omp2;


% memory usage %

if (isfield(params,'memusage'))
  switch lower(params.memusage)
    case 'low'
      memusage = MEM_LOW;
    case 'normal'
      memusage = MEM_NORMAL;
    case 'high'
      memusage = MEM_HIGH;
    otherwise
      error('Invalid memory usage mode');
  end
else
  memusage = MEM_NORMAL;
end


% mutual incoherence limit %

if (isfield(params,'muthresh'))
  muthresh = params.muthresh;
else
  muthresh = 0.99;
end
if (muthresh < 0)
  error('invalid muthresh value, must be non-negative');
end


% signal norms: keep signals above the error target, ordered by norm %

X_norm2 = sum(X.^2, 1);
X_norm = sqrt(X_norm2);
[X_norm_sort, p] = sort(X_norm);          % ascending
p1 = p(X_norm_sort > thresh);


% determine dictionary size %

dictsize = [];
if (isfield(params,'initdict'))
  if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:))))
    dictsize = length(params.initdict);
  else
    dictsize = size(params.initdict,2);
  end
end
if (isfield(params,'dictsize'))    % this supersedes the size determined by initdict
  dictsize = params.dictsize;
end
if (isempty(dictsize))
  error('Dictionary size not specified');
end

if (size(X,2) < dictsize)
  error('Number of training signals is smaller than number of atoms to train');
end


% initialize the dictionaries (both start from the same atoms) %

if (isfield(params,'initdict'))
  if (any(size(params.initdict)==1) && all(iswhole(params.initdict(:))))
    if (length(params.initdict) < dictsize)
      error('Invalid initial dictionary');
    end
    D1 = X(:,params.initdict(1:dictsize));
  else
    if (size(params.initdict,1)~=size(X,1) || size(params.initdict,2)<dictsize)
      error('Invalid initial dictionary');
    end
    D1 = params.initdict(:,1:dictsize);
  end
else
  data_ids = find(X_norm2 > 1e-6);    % ensure no zero data elements are chosen
  if (length(data_ids) < dictsize)
    error('Not enough non-zero training signals to initialize the dictionary');
  end
  perm = randperm(length(data_ids));
  D1 = X(:,data_ids(perm(1:dictsize)));
end
D2 = D1;


% recursive least-squares training %

maxsub = 20000;                     % max signals drawn per half in each pass
C1 = (100000*thresh)*eye(dictsize);
C2 = (100000*thresh)*eye(dictsize);
lambda = 0.9998;                    % forgetting factor, +0.0001 per pass

for j = 1:3
  nsig = size(p1,2);
  if (nsig == 0)
    break;
  end

  % Draw the same random (sorted) positions from the lower-norm half for
  % D1 and from the higher-norm half for D2. Never ask for more positions
  % than the half contains.
  half = floor(nsig/2);
  nsub = min(maxsub, half);
  perm = randperm(half);
  p2 = sort(perm(1:nsub));
  data1 = X(:,p1(p2));
  data2 = X(:,p1(half+p2));

  C1(1,1) = 0;
  C2(1,1) = 0;
  for i = 1:nsub
    [D1, C1] = rls_update(D1, C1, data1(:,i), thresh, lambda);
    [D2, C2] = rls_update(D2, C2, data2(:,i), thresh, lambda);
  end

  lambda = lambda + 0.0001;
end

Dictionary = [D1 D2];
end


function [D, C] = rls_update(D, C, x, thresh, lambda)
% One recursive least-squares dictionary step for signal x:
% 1. Sparse-code x against D with omp2mex (error target thresh).
% 2. Scale C by 1/lambda.
% 3. Apply the rank-one updates to D and C.
w = omp2mex(D,x,[],[],[],thresh,0,-1,-1,0);
spind = find(w);
residual = x - D*w;

C = C * (1/lambda);
u = C(:,spind) * w(spind);          % equals C*w, restricted to non-zero coefficients
alfa = 1/(1 + w'*u);
D = D + (alfa*residual)*u';
C = C - (alfa*u)*u';
end
362
363 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
364 % sparsecode %
365 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
366
function Gamma = sparsecode(data,D,XtX,G,thresh)
%SPARSECODE Sparse-code the columns of data over dictionary D using the
%function handle stored in the global ompfunc.
%
%   G and thresh are forwarded to ompfunc, followed by the options in the
%   global ompparams. XtX is accepted for interface compatibility but is not
%   used.
%
%   Arguments passed to ompfunc:
%     memusage >= MEM_HIGH and codemode == CODE_SPARSITY: D'*data, G, thresh.
%     Every other case: D, data, G, thresh.

global CODE_SPARSITY codemode
global MEM_HIGH memusage
global ompfunc ompparams

% ompparams is filled in only by other routines. An unset (empty, non-cell)
% global would make the {:} expansion below an error, so treat it as
% "no extra options".
if (iscell(ompparams))
  extraopts = ompparams;
else
  extraopts = {};
end

if (memusage < MEM_HIGH)
  Gamma = ompfunc(D,data,G,thresh,extraopts{:});
elseif (codemode == CODE_SPARSITY)
  % high memory usage in sparsity mode: pass the precomputed D'*data
  Gamma = ompfunc(D'*data,G,thresh,extraopts{:});
else
  Gamma = ompfunc(D,data,G,thresh,extraopts{:});
end

end
388
389
390 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
391 % compute_err %
392 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
393
394
function err = compute_err(D,Gamma,data)
%COMPUTE_ERR Error measure that matches the global coding mode.
%
%   If codemode equals CODE_SPARSITY, err is the root-mean-square error of
%   D*Gamma against data, taken over all entries of data. Otherwise err is
%   the average number of non-zero entries of Gamma per column of data.

global CODE_SPARSITY codemode

if (codemode == CODE_SPARSITY)
  sqerr = reperror2(data,D,Gamma);
  err = sqrt(sum(sqerr)/numel(data));
  return;
end

natoms = nnz(Gamma);
err = natoms/size(data,2);
end
406
407
408
409 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
410 % cleardict %
411 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
412
413
function [D,cleared_atoms] = cleardict(D,X,muthresh,unused_sigs,replaced_atoms)
%CLEARDICT Replace coherent or rarely used atoms of D with unused signals.
%
%   Atom j is replaced when replaced_atoms(j) is zero AND one of these holds:
%     - its squared inner product with another atom exceeds muthresh^2;
%     - its use count is below use_thresh.
%   The replacement is X(:,unused_sigs(end)) scaled to unit l2 norm, and that
%   index is then removed from unused_sigs. Coherence is computed against D
%   as already modified earlier in this loop. Replacement stops once
%   unused_sigs is exhausted.
%
%   Returns the updated D and the number of atoms replaced (cleared_atoms).
%
%   NOTE(review): usecount is taken from replaced_atoms itself. For any atom
%   with replaced_atoms(j)==0 the use-count test is therefore always true,
%   so the muthresh test has no effect. Confirm this is intended.

use_thresh = 4;   % at least this number of samples must use the atom to be kept

dictsize = size(D,2);
cleared_atoms = 0;
usecount = replaced_atoms;

for j = 1:dictsize

  if (isempty(unused_sigs))
    break;    % no replacement signals left
  end

  % compute G(:,j), ignoring the atom's inner product with itself
  Gj = D'*D(:,j);
  Gj(j) = 0;

  % replace atom
  if ( (max(Gj.^2)>muthresh^2 || usecount(j)<use_thresh) && ~replaced_atoms(j) )
    newatom = X(:,unused_sigs(end));
    D(:,j) = newatom / norm(newatom);
    unused_sigs = unused_sigs(1:end-1);
    cleared_atoms = cleared_atoms+1;
  end
end

end
446
447
448
449 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
450 % misc functions %
451 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
452
453
function err2 = reperror2(X,D,Gamma)
%REPERROR2 Squared l2 representation error of each column of X.
%
%   err2(k) = sum((X(:,k) - D*Gamma(:,k)).^2), returned as a 1-by-size(X,2)
%   row vector. Columns are processed in blocks of 2000 so that the temporary
%   D*Gamma product stays small.

err2 = zeros(1,size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
  blockids = i : min(i+blocksize-1,size(X,2));
  % Sum along dim 1 explicitly. For a single-row X, a plain sum() would
  % reduce the row to a scalar and broadcast it over the whole block.
  err2(blockids) = sum((X(:,blockids) - D*Gamma(:,blockids)).^2, 1);
end

end
465
466
function Y = colnorms_squared(X)
%COLNORMS_SQUARED Squared l2 norm of every column of X.
%
%   Y(k) = sum(X(:,k).^2), returned as a 1-by-size(X,2) row vector.
%   Columns are processed in blocks of 2000 to limit temporary memory.

Y = zeros(1,size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
  blockids = i : min(i+blocksize-1,size(X,2));
  % Sum along dim 1 explicitly. For a single-row X, a plain sum() would
  % reduce the row to a scalar and broadcast it over the whole block.
  Y(blockids) = sum(X(:,blockids).^2, 1);
end

end
478
479