function DL=SMALL_two_step_DL(Problem, DL)
%SMALL_TWO_STEP_DL Generic two-step (sparse coding / dictionary update) learner.
%   DL = SMALL_TWO_STEP_DL(Problem, DL) alternates, for a number of
%   iterations, between a sparse coding step (SMALL_solve with the solver
%   structure DL.param.solver) and a dictionary update step (dico_update
%   with the update type DL.name, e.g. 'KSVD', 'MOD', 'MOCOD', 'ols' or
%   'mailhe'). The training signals are the columns of Problem.b.
%   The learned dictionary is returned in DL.D.
%
%   Fields of DL.param read by this function:
%     solver       - solver structure passed to SMALL_solve (required)
%     initdict     - either a vector of whole numbers (column indices into
%                    Problem.b) or an explicit initial dictionary with
%                    size(Problem.b,1) rows
%     dictsize     - number of atoms; supersedes the size implied by initdict.
%                    Required when initdict is absent or empty.
%     flow         - 'sequential' (default) or 'parallel', passed to dico_update
%     learningRate - step parameter passed to dico_update (default 0.1)
%     iternum      - number of iterations (default 40)
%     coherence    - scalar target coherence mu for the decorrelation step
%     decFcn       - 'ink-svd', 'grassmannian', 'shrinkgram', 'iterproj' or
%                    'none' (default; written back into DL.param if absent).
%                    'ink-svd', 'shrinkgram' and 'iterproj' require coherence.
%     show_dict    - if present, display the dictionary every show_dict iterations
%     zeta, eta    - regularisation factors, read when DL.name is 'MOCOD'

% determine which solver is used for sparse representation
solver = DL.param.solver;

% determine which type of update to use ('KSVD', 'MOD', 'MOCOD', 'ols' or 'mailhe')
typeUpdate = DL.name;

sig = Problem.b;

% an initial dictionary is only used when the field is present and non-empty
hasInitDict = isfield(DL.param,'initdict') && ~isempty(DL.param.initdict);
% an initdict given as a vector of whole numbers is a list of column indices of sig
initIsIndexList = hasInitDict && any(size(DL.param.initdict)==1) && ...
    all(iswhole(DL.param.initdict(:)));

% determine dictionary size
if isfield(DL.param,'dictsize')  % this supersedes the size determined by initdict
    dictsize = DL.param.dictsize;
elseif initIsIndexList
    dictsize = length(DL.param.initdict);
elseif hasInitDict
    dictsize = size(DL.param.initdict,2);
else
    error('Dictionary size unspecified: provide DL.param.dictsize or a non-empty DL.param.initdict');
end

if (size(sig,2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end

% initialize the dictionary
if initIsIndexList
    if length(DL.param.initdict) < dictsize
        error('Invalid initial dictionary');
    end
    dico = sig(:,DL.param.initdict(1:dictsize));
elseif hasInitDict
    if (size(DL.param.initdict,1)~=size(sig,1) || size(DL.param.initdict,2)<dictsize)
        error('Invalid initial dictionary');
    end
    dico = DL.param.initdict(:,1:dictsize);
else
    % pick random training signals, ensuring no zero data elements are chosen
    data_ids = find(colnorms_squared(sig) > 1e-6);
    if length(data_ids) < dictsize
        error('Number of non-zero training signals is smaller than number of atoms to train');
    end
    perm = randperm(length(data_ids));
    dico = sig(:,data_ids(perm(1:dictsize)));
end

% flow: 'sequential' or 'parallel'. If sequential, the residual is updated
% after each atom update. If parallel, the residual is only updated once
% the whole dictionary has been computed. Sequential works better, there
% may be no need to implement parallel. Not used with MOD.
if isfield(DL.param,'flow')
    flow = DL.param.flow;
else
    flow = 'sequential';
end

% learningRate. If the type is 'ols', it is the descent step of
% the gradient (typical choice: 0.1). If the type is 'mailhe', the
% descent step is the optimal step*rho (typical choice: 1, although 2
% or 3 seems to work better). Not used for MOD and KSVD.
if isfield(DL.param,'learningRate')
    learningRate = DL.param.learningRate;
else
    learningRate = 0.1;
end

% number of iterations (default is 40)
if isfield(DL.param,'iternum')
    iternum = DL.param.iternum;
else
    iternum = 40;
end

% target coherence used by the decorrelation step (empty if not supplied)
if isfield(DL.param,'coherence') && isscalar(DL.param.coherence)
    mu = DL.param.coherence;
else
    mu = [];
end

% decorrelation method applied after every dictionary update
if ~isfield(DL.param,'decFcn'), DL.param.decFcn = 'none'; end
decFcn = lower(DL.param.decFcn);
% fail before any work is done rather than with an undefined mu mid-loop
if any(strcmp(decFcn, {'ink-svd','shrinkgram','iterproj'})) && isempty(mu)
    error('Decorrelation method ''%s'' requires a scalar DL.param.coherence', DL.param.decFcn);
end

% show dictionary every specified number of iterations
if isfield(DL.param,'show_dict')
    show_dictionary = 1;
    show_iter = DL.param.show_dict;
else
    show_dictionary = 0;
    show_iter = 0;
end

% Problem.b1 normally stores the signal that needs to be represented (for
% example the whole image). During dictionary learning we want the sparse
% representation of the training set instead, so b1 is replaced by the
% training signals. Problem is a local copy, so the caller's structure is
% not affected and nothing needs to be restored afterwards.
Problem.b1 = sig;
if isfield(Problem,'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end
solver.profile = 0;

% main loop
for i = 1:iternum
    % SPARSE CODING STEP
    Problem.A = dico;
    solver = SMALL_solve(Problem, solver);

    % DICTIONARY UPDATE STEP
    if strcmpi(typeUpdate,'mocod')
        % MOCOD takes an extra parameter structure
        mocodParams = struct('zeta',DL.param.zeta,... % coherence regularization factor
            'eta',DL.param.eta,...                    % atoms norm regularization factor
            'Dprev',dico);                            % previous dictionary
        dico = dico_update(dico,sig,solver.solution,typeUpdate,flow,learningRate,mocodParams);
    else
        [dico, solver.solution] = dico_update(dico, sig, solver.solution, ...
            typeUpdate, flow, learningRate);
        dico = normcols(dico);
    end

    % DECORRELATION STEP
    switch decFcn
        case 'ink-svd'
            dico = dico_decorr_symetric(dico,mu,solver.solution);
        case 'grassmannian'
            [n, m] = size(dico);
            dico = grassmannian(n,m,[],0.9,0.99,dico);
        case 'shrinkgram'
            dico = shrinkgram(dico,mu);
        case 'iterproj'
            dico = iterativeprojections(dico,mu,Problem.b1,solver.solution);
        otherwise
            % 'none' or unrecognised: no decorrelation
    end

    if ((show_dictionary)&&(mod(i,show_iter)==0))
        dictimg = SMALL_showdict(dico,[8 8],...
            round(sqrt(size(dico,2))),round(sqrt(size(dico,2))),'lines','highcontrast');
        figure(2); imagesc(dictimg);colormap(gray);axis off; axis image;
        pause(0.02);
    end
end

DL.D = dico;

end
function Y = colnorms_squared(X)
%COLNORMS_SQUARED Squared Euclidean norm of every column of X.
%   Y = COLNORMS_SQUARED(X) returns a 1-by-size(X,2) row vector with
%   Y(j) = sum(X(:,j).^2). Columns are processed in blocks of 2000 to
%   limit the size of the temporary array X(:,blockids).^2.

Y = zeros(1,size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
    blockids = i : min(i+blocksize-1,size(X,2));
    % sum along dimension 1 explicitly: for a single-row X the default
    % sum() would add along the row and assign one scalar to the whole block
    Y(blockids) = sum(X(:,blockids).^2, 1);
end

end