Mercurial > hg > smallbox
comparison DL/two-step DL/SMALL_two_step_DL.m @ 152:485747bf39e0 ivand_dev
Two-step dictionary learning - Integration of the code for dictionary update and dictionary decorrelation from Boris Mailhe
author | Ivan Damnjanovic lnx <ivan.damnjanovic@eecs.qmul.ac.uk> |
---|---|
date | Thu, 28 Jul 2011 15:49:32 +0100 |
parents | |
children | af307f247ac7 |
comparison
equal
deleted
inserted
replaced
149:fec205ec6ef6 | 152:485747bf39e0 |
---|---|
function DL=SMALL_two_step_DL(Problem, DL)
%SMALL_TWO_STEP_DL  Two-step dictionary learning (sparse coding + dictionary update).
%
%   DL = SMALL_two_step_DL(Problem, DL) alternates, for a fixed number of
%   iterations, between
%     1) sparse coding of the training signals Problem.b with the solver
%        structure DL.param.solver (via SMALL_solve), and
%     2) a dictionary update of type DL.name (via dico_update),
%   optionally followed by a decorrelation step (dico_decorr) when
%   DL.param.coherence is given. The learned dictionary is returned in DL.D.
%
%   Inputs
%     Problem - problem structure; Problem.b holds the training signals,
%               one signal per column.
%     DL      - dictionary-learning structure with fields
%               .name   update type ('KSVD', 'MOD', 'ols' or 'mailhe')
%               .param  parameter structure:
%                 .solver       sparse-coding solver structure (required)
%                 .initdict     initial dictionary (matrix of atoms) or a
%                               vector of column indices into Problem.b
%                 .dictsize     number of atoms; overrides the size implied
%                               by initdict. At least one of initdict /
%                               dictsize must be given.
%                 .flow         'sequential' (default) or 'parallel'
%                 .learningRate step for 'ols' / 'mailhe' (default 0.1)
%                 .iternum      number of iterations (default 40)
%                 .coherence    target coherence mu; if present, the
%                               dictionary is decorrelated every iteration
%                 .show_dict    display the dictionary every show_dict
%                               iterations
%
%   Output
%     DL      - the input structure with field D set to the learned
%               dictionary.

% sparse-coding solver %

solver = DL.param.solver;

% dictionary update type ('KSVD', 'MOD', 'ols' or 'mailhe') %

typeUpdate = DL.name;

sig = Problem.b;
nSignals = size(sig, 2);

% determine dictionary size %
% initdict is either a matrix of atoms or a vector of whole numbers that
% index columns of the training set; decide which once and reuse it below.

hasInitDict = isfield(DL.param, 'initdict');
initIsIndexList = false;
if hasInitDict
    initdict = DL.param.initdict;
    initIsIndexList = any(size(initdict)==1) && all(iswhole(initdict(:)));
    if initIsIndexList
        dictsize = length(initdict);
    else
        dictsize = size(initdict, 2);
    end
end
if isfield(DL.param, 'dictsize')   % this supersedes the size determined by initdict
    dictsize = DL.param.dictsize;
elseif ~hasInitDict
    error('Dictionary size is undefined: provide DL.param.dictsize or DL.param.initdict');
end

if (nSignals < dictsize)
    error('Number of training signals is smaller than number of atoms to train');
end

% initialize the dictionary %

if hasInitDict
    if initIsIndexList
        % dictsize may override length(initdict); make sure enough indices exist
        if (length(initdict) < dictsize)
            error('Invalid initial dictionary');
        end
        dico = sig(:, initdict(1:dictsize));
    else
        if (size(initdict,1) ~= size(sig,1) || size(initdict,2) < dictsize)
            error('Invalid initial dictionary');
        end
        dico = initdict(:, 1:dictsize);
    end
else
    % pick dictsize random training signals, skipping (near-)zero ones
    data_ids = find(colnorms_squared(sig) > 1e-6);
    if (length(data_ids) < dictsize)
        error('Number of non-zero training signals is smaller than number of atoms to train');
    end
    perm = randperm(length(data_ids));
    dico = sig(:, data_ids(perm(1:dictsize)));
end

% flow: 'sequential' or 'parallel'. If sequential, the residual is updated
% after each atom update. If parallel, the residual is only updated once
% the whole dictionary has been computed. Sequential works better, there
% may be no need to implement parallel. Not used with MOD.

if isfield(DL.param, 'flow')
    flow = DL.param.flow;
else
    flow = 'sequential';
end

% learningRate. If the type is 'ols', it is the descent step of
% the gradient (typical choice: 0.1). If the type is 'mailhe', the
% descent step is the optimal step*rho (typical choice: 1, although 2
% or 3 seems to work better). Not used for MOD and KSVD.

if isfield(DL.param, 'learningRate')
    learningRate = DL.param.learningRate;
else
    learningRate = 0.1;
end

% number of iterations (default is 40) %

if isfield(DL.param, 'iternum')
    iternum = DL.param.iternum;
else
    iternum = 40;
end

% decorrelate the dictionary in every iteration if a coherence is given %

if isfield(DL.param, 'coherence')
    decorrelate = 1;
    mu = DL.param.coherence;
else
    decorrelate = 0;
end

% show the dictionary every specified number of iterations %

if isfield(DL.param, 'show_dict')
    show_dictionary = 1;
    show_iter = DL.param.show_dict;
else
    show_dictionary = 0;
    show_iter = 0;
end

% Dictionary learning needs the sparse representation of the training set,
% but in this version of the software Problem.b1 stores the signal that
% ultimately has to be represented (for example the whole image). Point b1
% at the training set for the solver calls below. Problem is a local copy,
% so the caller's structure is not modified.

Problem.b1 = sig;
% NOTE(review): presumably removed so SMALL_solve skips the problem's
% reconstruction step during learning - confirm
if isfield(Problem, 'reconstruct')
    Problem = rmfield(Problem, 'reconstruct');
end
solver.profile = 0;   % presumably disables solver profiling output - confirm

% main loop %

for i = 1:iternum
    % sparse coding against the current dictionary; assigning A first
    % ensures the initial dictionary is used on the first iteration
    Problem.A = dico;
    solver = SMALL_solve(Problem, solver);

    % dictionary update; the coefficients returned by dico_update are
    % written back into solver.solution
    [dico, solver.solution] = dico_update(dico, sig, solver.solution, ...
        typeUpdate, flow, learningRate);

    if decorrelate
        dico = dico_decorr(dico, mu, solver.solution);
    end

    if (show_dictionary && mod(i, show_iter) == 0)
        % atoms are drawn as [8 8] patches (hard-coded; presumably assumes
        % 64-dimensional atoms - TODO confirm)
        nTiles = round(sqrt(size(dico, 2)));
        dictimg = SMALL_showdict(dico, [8 8], nTiles, nTiles, 'lines', 'highcontrast');
        figure(2); imagesc(dictimg); colormap(gray); axis off; axis image;
        pause(0.02);
    end
end

DL.D = dico;

end