function DL = SMALL_two_step_DL(Problem, DL)
%% DL=SMALL_two_step_DL(Problem, DL) learn a dictionary using two_step_DL
% Alternates a sparse-coding step (SMALL_solve with DL.param.solver) and a
% dictionary-update step (dico_update) for a fixed number of iterations.
% When DL.param.coherence is set, dico_decorr_symetric is also applied
% after every update. The learned dictionary is returned in DL.D.
%
% Inputs:
%   Problem - SMALL problem structure; Problem.b holds the training
%             signals, one per column.
%   DL      - dictionary-learning structure; DL.name selects the update
%             rule and DL.param holds its parameters.
%
% The specific parameters of the DL structure are:
% -name: can be either 'ols', 'opt', 'MOD', 'KSVD' or 'LGD'.
% -param.solver: solver structure passed to SMALL_solve for the sparse
%   coding step (required).
% -param.initdict: either an initial dictionary (one atom per column, same
%   number of rows as Problem.b) or a vector of column indices into
%   Problem.b. If absent, dictsize non-zero training signals are picked
%   at random.
% -param.dictsize: number of atoms. Supersedes the size implied by
%   initdict; required when initdict is absent.
% -param.learningRate: a step size used by 'ols' and 'opt'. Default: 0.1
%   for 'ols', 1 otherwise.
% -param.flow: can be either 'sequential' or 'parallel'. Default:
%   'sequential'. Not used by MOD.
% -param.iternum: number of iterations. Default: 40.
% -param.coherence: a real number between 0 and 1. If present, then
%   a low-coherence constraint is added to the learning.
% -param.show_dict: if present, the dictionary is drawn (SMALL_showdict,
%   figure 2) every show_dict iterations.
%
% See dico_update.m for more details.

% determine which solver is used for sparse representation %
solver = DL.param.solver;

% determine which type of update to use ('KSVD', 'MOD', 'ols', 'opt' or 'LGD') %
typeUpdate = DL.name;

sig = Problem.b;

% determine dictionary size %
% An initdict that is a vector of whole numbers is read as a list of
% column indices into sig rather than as a dictionary matrix.
hasInitDict = isfield(DL.param,'initdict');
initIsIndexList = hasInitDict && any(size(DL.param.initdict)==1) && ...
    all(iswhole(DL.param.initdict(:)));

dictsize = [];
if (hasInitDict)
    if (initIsIndexList)
        dictsize = length(DL.param.initdict);
    else
        dictsize = size(DL.param.initdict,2);
    end
end
if (isfield(DL.param,'dictsize')) % this supersedes the size determined by initdict
    dictsize = DL.param.dictsize;
end
if (isempty(dictsize))
    % without either field there is nothing to size the dictionary from
    error('Dictionary size undefined: set DL.param.dictsize or DL.param.initdict');
end

if (size(sig,2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end

% initialize the dictionary %
if (initIsIndexList)
    dico = sig(:,DL.param.initdict(1:dictsize));
elseif (hasInitDict)
    if (size(DL.param.initdict,1)~=size(sig,1) || size(DL.param.initdict,2)<dictsize)
        error('Invalid initial dictionary');
    end
    dico = DL.param.initdict(:,1:dictsize);
else
    data_ids = find(colnorms_squared(sig) > 1e-6);   % ensure no zero data elements are chosen
    % zero columns are excluded above, so the column-count check is not enough
    if (length(data_ids) < dictsize)
        error('Number of non-zero training signals is smaller than number of atoms to train');
    end
    perm = randperm(length(data_ids));
    dico = sig(:,data_ids(perm(1:dictsize)));
end

% flow: 'sequential' or 'parallel'. If sequential, the residual is updated
% after each atom update. If parallel, the residual is only updated once
% the whole dictionary has been computed. Sequential works better, there
% may be no need to implement parallel. Not used with MOD.
if isfield(DL.param,'flow')
    flow = DL.param.flow;
else
    flow = 'sequential';
end

% learningRate. If the type is 'ols', it is the descent step of
% the gradient (default: 0.1). If the type is 'opt', the
% descent step is the optimal step*rho (default: 1, although 2 works
% better). Not used for MOD and KSVD.
if isfield(DL.param,'learningRate')
    learningRate = DL.param.learningRate;
else
    switch typeUpdate
        case 'ols'
            learningRate = 0.1;
        otherwise
            learningRate = 1;
    end
end

% number of iterations (default is 40) %
if isfield(DL.param,'iternum')
    iternum = DL.param.iternum;
else
    iternum = 40;
end

% determine if we should do decorrelation in every iteration %
if isfield(DL.param,'coherence')
    decorrelate = 1;
    mu = DL.param.coherence;
else
    decorrelate = 0;
end

% show dictionary every specified number of iterations
if isfield(DL.param,'show_dict')
    show_dictionary = 1;
    show_iter = DL.param.show_dict;
else
    show_dictionary = 0;
    show_iter = 0;
end

% In this version of the software Problem.b1 stores the signal that needs
% to be represented (for example the whole image), while dictionary
% learning needs the sparse representation of the training set, so b1 is
% pointed at the training signals. Problem is a local copy (structs are
% passed by value and Problem is not returned), so the caller's b1 is left
% untouched and no save/restore is needed.
Problem.b1 = sig;
% NOTE(review): 'reconstruct' is removed presumably so SMALL_solve skips
% any reconstruction step during learning - confirm against SMALL_solve.
if isfield(Problem,'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end
solver.profile = 0;

% main loop %
for i = 1:iternum
    Problem.A = dico;
    solver = SMALL_solve(Problem, solver);
    [dico, solver.solution] = dico_update(dico, sig, solver.solution, ...
        typeUpdate, flow, learningRate);
    if (decorrelate)
        dico = dico_decorr_symetric(dico, mu, solver.solution);
    end

    if (show_dictionary && mod(i,show_iter)==0)
        nSide = round(sqrt(size(dico,2)));
        dictimg = SMALL_showdict(dico,[8 8],nSide,nSide,'lines','highcontrast');
        figure(2); imagesc(dictimg); colormap(gray); axis off; axis image;
        pause(0.02);
    end
end

DL.D = dico;

end
function Y = colnorms_squared(X)
% Y = colnorms_squared(X) returns the squared Euclidean norm of every
% column of X as a 1-by-size(X,2) row vector.
%
% The columns are processed in blocks of 2000 to conserve memory.
% The sum is taken explicitly along dimension 1, so a single-row X still
% yields one value per column; a plain sum() would collapse a row vector
% to a scalar. abs() keeps the result correct for complex-valued X and is
% exact for real X.

Y = zeros(1,size(X,2));
blocksize = 2000;
for i = 1:blocksize:size(X,2)
    blockids = i : min(i+blocksize-1,size(X,2));
    Y(blockids) = sum(abs(X(:,blockids)).^2, 1);
end

end