comparison DL/two-step DL/SMALL_two_step_DL.m @ 224:fd0b5d36f6ad danieleb

Updated the contents of this branch with the contents of the default branch.
author luisf <luis.figueira@eecs.qmul.ac.uk>
date Thu, 12 Apr 2012 13:52:28 +0100
parents d0645d5fca7d f12a476a4977
children
comparison
equal deleted inserted replaced
196:82b0d3f982cb 224:fd0b5d36f6ad
1 function DL=SMALL_two_step_DL(Problem, DL) 1 function DL=SMALL_two_step_DL(Problem, DL)
2
3 %% DL=SMALL_two_step_DL(Problem, DL) learn a dictionary using two_step_DL
4 % The specific parameters of the DL structure are:
5 % -name: can be either 'ols', 'opt', 'MOD', 'KSVD' or 'LGD'.
6 % -param.learningRate: a step size used by 'ols' and 'opt'. Default: 0.1
7 % for 'ols', 1 for 'opt'.
8 % -param.flow: can be either 'sequential' or 'parallel'. Default:
9 % 'sequential'. Not used by MOD.
10 % -param.coherence: a real number between 0 and 1. If present, then
11 % a low-coherence constraint is added to the learning.
12 %
13 % See dico_update.m for more details.
2 14
3 % determine which solver is used for sparse representation % 15 % determine which solver is used for sparse representation %
4 16
5 solver = DL.param.solver; 17 solver = DL.param.solver;
6 18
7 % determine which type of udate to use ('KSVD', 'MOD','MOCOD','ols' or 'mailhe') % 19 % determine which type of udate to use ('KSVD', 'MOD', 'ols', 'opt' or 'LGD') %
8 20
9 typeUpdate = DL.name; 21 typeUpdate = DL.name;
10 22
11 sig = Problem.b; 23 sig = Problem.b;
12 24
28 end 40 end
29 41
30 42
31 % initialize the dictionary % 43 % initialize the dictionary %
32 44
33 if (isfield(DL.param,'initdict')) && ~isempty(DL.param.initdict); 45 if (isfield(DL.param,'initdict'))
34 if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:)))) 46 if (any(size(DL.param.initdict)==1) && all(iswhole(DL.param.initdict(:))))
35 dico = sig(:,DL.param.initdict(1:dictsize)); 47 dico = sig(:,DL.param.initdict(1:dictsize));
36 else 48 else
37 if (size(DL.param.initdict,1)~=size(sig,1) || size(DL.param.initdict,2)<dictsize) 49 if (size(DL.param.initdict,1)~=size(sig,1) || size(DL.param.initdict,2)<dictsize)
38 error('Invalid initial dictionary'); 50 error('Invalid initial dictionary');
55 else 67 else
56 flow = 'sequential'; 68 flow = 'sequential';
57 end 69 end
58 70
59 % learningRate. If the type is 'ols', it is the descent step of 71 % learningRate. If the type is 'ols', it is the descent step of
60 % the gradient (typical choice: 0.1). If the type is 'mailhe', the 72 % the gradient (default: 0.1). If the type is 'mailhe', the
61 % descent step is the optimal step*rho (typical choice: 1, although 2 73 % descent step is the optimal step*rho (default: 1, although 2 works
62 % or 3 seems to work better). Not used for MOD and KSVD. 74 % better). Not used for MOD and KSVD.
63 75
64 if isfield(DL.param,'learningRate') 76 if isfield(DL.param,'learningRate')
65 learningRate = DL.param.learningRate; 77 learningRate = DL.param.learningRate;
66 else 78 else
67 learningRate = 0.1; 79 switch typeUpdate
80 case 'ols'
81 learningRate = 0.1;
82 otherwise
83 learningRate = 1;
84 end
68 end 85 end
69 86
70 % number of iterations (default is 40) % 87 % number of iterations (default is 40) %
71 88
72 if isfield(DL.param,'iternum') 89 if isfield(DL.param,'iternum')
74 else 91 else
75 iternum = 40; 92 iternum = 40;
76 end 93 end
77 % determine if we should do decorrelation in every iteration % 94 % determine if we should do decorrelation in every iteration %
78 95
79 if isfield(DL.param,'coherence') && isscalar(DL.param.coherence) 96 if isfield(DL.param,'coherence')
80 decorrelate = 1; 97 decorrelate = 1;
81 mu = DL.param.coherence; 98 mu = DL.param.coherence;
82 else 99 else
83 decorrelate = 0; 100 decorrelate = 0;
84 end 101 end
85
86 if ~isfield(DL.param,'decFcn'), DL.param.decFcn = 'none'; end
87 102
88 % show dictonary every specified number of iterations 103 % show dictonary every specified number of iterations
89 104
90 if isfield(DL.param,'show_dict') 105 if isfield(DL.param,'show_dict')
91 show_dictionary=1; 106 show_dictionary=1;
108 solver.profile = 0; 123 solver.profile = 0;
109 124
110 % main loop % 125 % main loop %
111 126
112 for i = 1:iternum 127 for i = 1:iternum
113 %disp([num2str(i) '/' num2str(iternum)]); 128 Problem.A = dico;
114 %SPARSE CODING STEP
115 Problem.A = dico;
116 solver = SMALL_solve(Problem, solver); 129 solver = SMALL_solve(Problem, solver);
117 %DICTIONARY UPDATE STEP 130 [dico, solver.solution] = dico_update(dico, sig, solver.solution, ...
118 if strcmpi(typeUpdate,'mocod') %if update is MOCOD create parameters structure 131 typeUpdate, flow, learningRate);
119 mocodParams = struct('zeta',DL.param.zeta,... %coherence regularization factor 132 if (decorrelate)
120 'eta',DL.param.eta,... %atoms norm regularization factor 133 dico = dico_decorr_symetric(dico, mu, solver.solution);
121 'Dprev',dico); %previous dictionary 134 end
122 dico = dico_update(dico,sig,solver.solution,typeUpdate,flow,learningRate,mocodParams);
123 else
124 [dico, solver.solution] = dico_update(dico, sig, solver.solution, ...
125 typeUpdate, flow, learningRate);
126 dico = normcols(dico);
127 end
128
129 switch lower(DL.param.decFcn)
130 case 'ink-svd'
131 dico = dico_decorr_symetric(dico,mu,solver.solution);
132 case 'grassmannian'
133 [n m] = size(dico);
134 dico = grassmannian(n,m,[],0.9,0.99,dico);
135 case 'shrinkgram'
136 dico = shrinkgram(dico,mu);
137 case 'iterproj'
138 dico = iterativeprojections(dico,mu,Problem.b1,solver.solution);
139 otherwise
140 end
141 135
142 if ((show_dictionary)&&(mod(i,show_iter)==0)) 136 if ((show_dictionary)&&(mod(i,show_iter)==0))
143 dictimg = SMALL_showdict(dico,[8 8],... 137 dictimg = SMALL_showdict(dico,[8 8],...
144 round(sqrt(size(dico,2))),round(sqrt(size(dico,2))),'lines','highcontrast'); 138 round(sqrt(size(dico,2))),round(sqrt(size(dico,2))),'lines','highcontrast');
145 figure(2); imagesc(dictimg);colormap(gray);axis off; axis image; 139 figure(2); imagesc(dictimg);colormap(gray);axis off; axis image;