Mercurial > hg > smallbox
comparison DL/two-step DL/SMALL_two_step_DL.m @ 224:fd0b5d36f6ad danieleb
Updated the contents of this branch with the contents of the default branch.
author | luisf <luis.figueira@eecs.qmul.ac.uk> |
---|---|
date | Thu, 12 Apr 2012 13:52:28 +0100 |
parents | d0645d5fca7d f12a476a4977 |
children |
comparison
equal
deleted
inserted
replaced
196:82b0d3f982cb | 224:fd0b5d36f6ad |
---|---|
1 function DL=SMALL_two_step_DL(Problem, DL) | 1 function DL=SMALL_two_step_DL(Problem, DL) |
2 | |
3 %% DL=SMALL_two_step_DL(Problem, DL) learn a dictionary using two_step_DL | |
4 % The specific parameters of the DL structure are: | |
5 % -name: can be either 'ols', 'opt', 'MOD', KSVD' or 'LGD'. | |
6 % -param.learningRate: a step size used by 'ols' and 'opt'. Default: 0.1 | |
7 % for 'ols', 1 for 'opt'. | |
8 % -param.flow: can be either 'sequential' or 'parallel'. De fault: | |
9 % 'sequential'. Not used by MOD. | |
10 % -param.coherence: a real number between 0 and 1. If present, then | |
11 % a low-coherence constraint is added to the learning. | |
12 % | |
13 % See dico_update.m for more details. | |
2 | 14 |
3 % determine which solver is used for sparse representation % | 15 % determine which solver is used for sparse representation % |
4 | 16 |
5 solver = DL.param.solver; | 17 solver = DL.param.solver; |
6 | 18 |
7 % determine which type of udate to use ('KSVD', 'MOD','MOCOD','ols' or 'mailhe') % | 19 % determine which type of udate to use ('KSVD', 'MOD', 'ols', 'opt' or 'LGD') % |
8 | 20 |
9 typeUpdate = DL.name; | 21 typeUpdate = DL.name; |
10 | 22 |
11 sig = Problem.b; | 23 sig = Problem.b; |
12 | 24 |
28 end | 40 end |
29 | 41 |
30 | 42 |
31 % initialize the dictionary % | 43 % initialize the dictionary % |
32 | 44 |
33 if (isfield(DL.param,'initdict')) && ~isempty(DL.param.initdict); | 45 if (isfield(DL.param,'initdict')) |
34 if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:)))) | 46 if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:)))) |
35 dico = sig(:,DL.param.initdict(1:dictsize)); | 47 dico = sig(:,DL.param.initdict(1:dictsize)); |
36 else | 48 else |
37 if (size(DL.param.initdict,1)~=size(sig,1) || size(DL.param.initdict,2)<dictsize) | 49 if (size(DL.param.initdict,1)~=size(sig,1) || size(DL.param.initdict,2)<dictsize) |
38 error('Invalid initial dictionary'); | 50 error('Invalid initial dictionary'); |
55 else | 67 else |
56 flow = 'sequential'; | 68 flow = 'sequential'; |
57 end | 69 end |
58 | 70 |
59 % learningRate. If the type is 'ols', it is the descent step of | 71 % learningRate. If the type is 'ols', it is the descent step of |
60 % the gradient (typical choice: 0.1). If the type is 'mailhe', the | 72 % the gradient (default: 0.1). If the type is 'mailhe', the |
61 % descent step is the optimal step*rho (typical choice: 1, although 2 | 73 % descent step is the optimal step*rho (default: 1, although 2 works |
62 % or 3 seems to work better). Not used for MOD and KSVD. | 74 % better). Not used for MOD and KSVD. |
63 | 75 |
64 if isfield(DL.param,'learningRate') | 76 if isfield(DL.param,'learningRate') |
65 learningRate = DL.param.learningRate; | 77 learningRate = DL.param.learningRate; |
66 else | 78 else |
67 learningRate = 0.1; | 79 switch typeUpdate |
80 case 'ols' | |
81 learningRate = 0.1; | |
82 otherwise | |
83 learningRate = 1; | |
84 end | |
68 end | 85 end |
69 | 86 |
70 % number of iterations (default is 40) % | 87 % number of iterations (default is 40) % |
71 | 88 |
72 if isfield(DL.param,'iternum') | 89 if isfield(DL.param,'iternum') |
74 else | 91 else |
75 iternum = 40; | 92 iternum = 40; |
76 end | 93 end |
77 % determine if we should do decorrelation in every iteration % | 94 % determine if we should do decorrelation in every iteration % |
78 | 95 |
79 if isfield(DL.param,'coherence') && isscalar(DL.param.coherence) | 96 if isfield(DL.param,'coherence') |
80 decorrelate = 1; | 97 decorrelate = 1; |
81 mu = DL.param.coherence; | 98 mu = DL.param.coherence; |
82 else | 99 else |
83 decorrelate = 0; | 100 decorrelate = 0; |
84 end | 101 end |
85 | |
86 if ~isfield(DL.param,'decFcn'), DL.param.decFcn = 'none'; end | |
87 | 102 |
88 % show dictonary every specified number of iterations | 103 % show dictonary every specified number of iterations |
89 | 104 |
90 if isfield(DL.param,'show_dict') | 105 if isfield(DL.param,'show_dict') |
91 show_dictionary=1; | 106 show_dictionary=1; |
108 solver.profile = 0; | 123 solver.profile = 0; |
109 | 124 |
110 % main loop % | 125 % main loop % |
111 | 126 |
112 for i = 1:iternum | 127 for i = 1:iternum |
113 %disp([num2str(i) '/' num2str(iternum)]); | 128 Problem.A = dico; |
114 %SPARSE CODING STEP | |
115 Problem.A = dico; | |
116 solver = SMALL_solve(Problem, solver); | 129 solver = SMALL_solve(Problem, solver); |
117 %DICTIONARY UPDATE STEP | 130 [dico, solver.solution] = dico_update(dico, sig, solver.solution, ... |
118 if strcmpi(typeUpdate,'mocod') %if update is MOCOD create parameters structure | 131 typeUpdate, flow, learningRate); |
119 mocodParams = struct('zeta',DL.param.zeta,... %coherence regularization factor | 132 if (decorrelate) |
120 'eta',DL.param.eta,... %atoms norm regularization factor | 133 dico = dico_decorr_symetric(dico, mu, solver.solution); |
121 'Dprev',dico); %previous dictionary | 134 end |
122 dico = dico_update(dico,sig,solver.solution,typeUpdate,flow,learningRate,mocodParams); | |
123 else | |
124 [dico, solver.solution] = dico_update(dico, sig, solver.solution, ... | |
125 typeUpdate, flow, learningRate); | |
126 dico = normcols(dico); | |
127 end | |
128 | |
129 switch lower(DL.param.decFcn) | |
130 case 'ink-svd' | |
131 dico = dico_decorr_symetric(dico,mu,solver.solution); | |
132 case 'grassmannian' | |
133 [n m] = size(dico); | |
134 dico = grassmannian(n,m,[],0.9,0.99,dico); | |
135 case 'shrinkgram' | |
136 dico = shrinkgram(dico,mu); | |
137 case 'iterproj' | |
138 dico = iterativeprojections(dico,mu,Problem.b1,solver.solution); | |
139 otherwise | |
140 end | |
141 | 135 |
142 if ((show_dictionary)&&(mod(i,show_iter)==0)) | 136 if ((show_dictionary)&&(mod(i,show_iter)==0)) |
143 dictimg = SMALL_showdict(dico,[8 8],... | 137 dictimg = SMALL_showdict(dico,[8 8],... |
144 round(sqrt(size(dico,2))),round(sqrt(size(dico,2))),'lines','highcontrast'); | 138 round(sqrt(size(dico,2))),round(sqrt(size(dico,2))),'lines','highcontrast'); |
145 figure(2); imagesc(dictimg);colormap(gray);axis off; axis image; | 139 figure(2); imagesc(dictimg);colormap(gray);axis off; axis image; |