comparison DL/two-step DL/SMALL_two_step_DL.m @ 152:485747bf39e0 ivand_dev

Two-step dictionary learning - Integration of the code for dictionary update and dictionary decorrelation from Boris Mailhe
author Ivan Damnjanovic lnx <ivan.damnjanovic@eecs.qmul.ac.uk>
date Thu, 28 Jul 2011 15:49:32 +0100
parents
children af307f247ac7
comparison
equal deleted inserted replaced
149:fec205ec6ef6 152:485747bf39e0
function DL=SMALL_two_step_DL(Problem, DL)
%SMALL_TWO_STEP_DL  Two-step dictionary learning (sparse coding + update).
%
%   DL = SMALL_two_step_DL(Problem, DL) alternates, for a fixed number of
%   iterations, between
%     1) sparse coding of the training signals with SMALL_solve, and
%     2) a dictionary update with dico_update, optionally followed by a
%        decorrelation step with dico_decorr,
%   and stores the learned dictionary in DL.D.
%
%   Fields read from Problem:
%     b            - training signals, one signal per column.
%     reconstruct  - removed from the local copy before solving (optional).
%
%   Fields read from DL:
%     name               - type of dictionary update, passed to dico_update
%                          ('KSVD', 'MOD', 'ols' or 'mailhe').
%     param.solver       - solver structure handed to SMALL_solve.
%     param.initdict     - (optional) either an initial dictionary matrix or
%                          a vector of column indices into Problem.b.
%     param.dictsize     - (optional) number of atoms; supersedes the size
%                          derived from param.initdict.
%     param.flow         - (optional) 'sequential' (default) or 'parallel'.
%     param.learningRate - (optional) descent step, default 0.1.
%     param.iternum      - (optional) number of iterations, default 40.
%     param.coherence    - (optional) if present, the dictionary is
%                          decorrelated in every iteration with this value.
%     param.show_dict    - (optional) if present, the dictionary is
%                          displayed every show_dict iterations.
%
%   Errors are raised when the dictionary size cannot be determined, when
%   there are fewer training signals than atoms, or when the initial
%   dictionary matrix has inconsistent dimensions.

% determine which solver is used for sparse representation %

solver = DL.param.solver;

% determine which type of update to use ('KSVD', 'MOD', 'ols' or 'mailhe') %

typeUpdate = DL.name;

sig = Problem.b;

% determine dictionary size %

dictsize = [];
if (isfield(DL.param,'initdict'))
    if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:))))
        % initdict is a vector of signal indices
        dictsize = length(DL.param.initdict);
    else
        % initdict is a dictionary matrix (atoms in columns)
        dictsize = size(DL.param.initdict,2);
    end
end
if (isfield(DL.param,'dictsize'))    % this supersedes the size determined by initdict
    dictsize = DL.param.dictsize;
end

% Without either field the size is unknown; fail with a clear message
% instead of an "undefined variable" error further down.
if (isempty(dictsize))
    error('Dictionary size not specified');
end

if (size(sig,2) < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end


% initialize the dictionary %

if (isfield(DL.param,'initdict'))
    if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:))))
        dico = sig(:,DL.param.initdict(1:dictsize));
    else
        if (size(DL.param.initdict,1)~=size(sig,1) || size(DL.param.initdict,2)<dictsize)
            error('Invalid initial dictionary');
        end
        dico = DL.param.initdict(:,1:dictsize);
    end
else
    data_ids = find(colnorms_squared(sig) > 1e-6);   % ensure no zero data elements are chosen
    perm = randperm(length(data_ids));
    dico = sig(:,data_ids(perm(1:dictsize)));
end

% flow: 'sequential' or 'parallel'. If sequential, the residual is updated
% after each atom update. If parallel, the residual is only updated once
% the whole dictionary has been computed. Sequential works better, there
% may be no need to implement parallel. Not used with MOD.

if isfield(DL.param,'flow')
    flow = DL.param.flow;
else
    flow = 'sequential';
end

% learningRate. If the type is 'ols', it is the descent step of
% the gradient (typical choice: 0.1). If the type is 'mailhe', the
% descent step is the optimal step*rho (typical choice: 1, although 2
% or 3 seems to work better). Not used for MOD and KSVD.

if isfield(DL.param,'learningRate')
    learningRate = DL.param.learningRate;
else
    learningRate = 0.1;
end

% number of iterations (default is 40) %

if isfield(DL.param,'iternum')
    iternum = DL.param.iternum;
else
    iternum = 40;
end

% determine if we should do decorrelation in every iteration %

if isfield(DL.param,'coherence')
    decorrelate = 1;
    mu = DL.param.coherence;
else
    decorrelate = 0;
end

% show dictionary every specified number of iterations

if (isfield(DL.param,'show_dict'))
    show_dictionary=1;
    show_iter=DL.param.show_dict;
else
    show_dictionary=0;
    show_iter=0;
end

% This is a small patch that needs to be resolved: in dictionary learning we
% want the sparse representation of the training set, while in this version
% of the software Problem.b1 stores the signal that needs to be represented
% (for example the whole image). Problem is a local copy that is never
% returned, so the caller's b1 is untouched and nothing has to be restored.

Problem.b1 = sig;
if (isfield(Problem,'reconstruct'))
    % rmfield errors on a missing field, so only remove it when present
    Problem = rmfield(Problem, 'reconstruct');
end
solver.profile = 0;

% main loop %

for i = 1:iternum
    % Sparse coding must use the current dictionary. Assigning it here
    % (rather than at the end of the iteration) makes the first pass use
    % the initial dictionary built above.
    Problem.A = dico;
    solver = SMALL_solve(Problem, solver);
    [dico, solver.solution] = dico_update(dico, sig, solver.solution, ...
        typeUpdate, flow, learningRate);
    if (decorrelate)
        dico = dico_decorr(dico, mu, solver.solution);
    end
    if ((show_dictionary)&&(mod(i,show_iter)==0))
        dictimg = SMALL_showdict(dico,[8 8],...
            round(sqrt(size(dico,2))),round(sqrt(size(dico,2))),'lines','highcontrast');
        figure(2); imagesc(dictimg);colormap(gray);axis off; axis image;
        pause(0.02);
    end
end

DL.D = dico;

end