Mercurial > hg > smallbox
comparison DL/two-step DL/dico_update.m @ 152:485747bf39e0 ivand_dev
Two-step dictionary learning - Integration of the code for dictionary update and dictionary decorrelation from Boris Mailhe
author | Ivan Damnjanovic lnx <ivan.damnjanovic@eecs.qmul.ac.uk> |
---|---|
date | Thu, 28 Jul 2011 15:49:32 +0100 |
parents | |
children | 9eb5f0d4c1a4 5140b0e06c22 |
comparison
equal
deleted
inserted
replaced
149:fec205ec6ef6 | 152:485747bf39e0 |
---|---|
function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
%
% Perform one iteration of dictionary update for dictionary learning.
% The model is sig ~ dico*amp.
%
% parameters:
% - dico: the initial dictionary with atoms as columns
% - sig: the training data (same number of rows as dico, same number of
%   columns as amp)
% - amp: the amplitude coefficients as a sparse matrix (one row per atom)
% - type: the algorithm, one of the following
%   - 'ols': fixed step gradient descent
%   - 'mailhe': optimal step gradient descent (could become the default
%     for 'ols')
%   - 'MOD': least-squares update through the pseudo-inverse of the
%     coefficients
%   - 'KSVD': rank-1 SVD update of each atom restricted to the signals
%     that use it; the corresponding coefficients are re-estimated too
%   - 'rand': picks 'MOD', 'mailhe' or 'KSVD' with equal probability
% - flow: 'sequential' (default) or 'parallel'. If sequential, the
%   residual is updated after each atom update. If parallel, the residual
%   computed at the start is used for every atom. Sequential works
%   better. Not used with MOD.
% - rho: learning rate (default: 0.1). If the type is 'ols', it is the
%   descent step of the gradient (typical choice: 0.1). If the type is
%   'mailhe', the descent step is the optimal step*rho (typical choice: 1,
%   although 2 or 3 seems to work better). Not used for MOD and KSVD.
%
% outputs:
% - dico: the updated dictionary; every atom that was updated has unit
%   norm, atoms that could not be updated are returned unchanged
% - amp: the coefficients; rescaled by MOD so that dico*amp is consistent
%   with the renormalized atoms, re-estimated by KSVD, unchanged by 'ols'
%   and 'mailhe'
%
% An unknown type raises an error with identifier
% 'dico_update:unknownType'.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if ~exist('rho', 'var') || isempty(rho)
    rho = 0.1;
end

if ~exist('flow', 'var') || isempty(flow)
    % Must be a string literal: an unquoted identifier would be looked up
    % as a function/variable and crash.
    flow = 'sequential';
end
% flow is always defined at this point, so no nargin test is needed.
is_sequential = strcmp(flow, 'sequential');

if strcmp(type, 'rand')
    x = rand();
    if x < 1/3
        type = 'MOD';
    elseif x < 2/3
        type = 'mailhe';
    else
        type = 'KSVD';
    end
end

res = sig - dico*amp;
nb_pattern = size(dico, 2);

switch type
    case 'MOD'
        % Least-squares dictionary sig*pinv(amp), computed as
        % sig*amp'*pinv(amp*amp') so only the small Gram matrix is
        % (pseudo-)inverted. Unlike G^-1, pinv stays finite when an atom
        % is unused (singular G): that atom gets a zero column, and the
        % n > 0 guard below keeps the previous atom.
        G = full(amp*amp');
        dico2 = (sig*amp')*pinv(G);
        for p = 1:nb_pattern
            n = norm(dico2(:,p));
            if n > 0
                % Renormalize the atom and move its scale into the
                % coefficients so that dico*amp is unchanged.
                dico(:,p) = dico2(:,p)/n;
                amp(p,:) = amp(p,:)*n;
            end
        end
    case 'ols'
        for p = 1:nb_pattern
            grad = res*amp(p,:)';
            if norm(grad) > 0
                pat = dico(:,p) + rho*grad;
                pat = pat/norm(pat);
                if is_sequential
                    % Account for the new atom before updating the next one.
                    res = res + (dico(:,p)-pat)*amp(p,:);
                end
                dico(:,p) = pat;
            end
        end
    case 'mailhe'
        for p = 1:nb_pattern
            grad = res*amp(p,:)';
            if norm(grad) > 0
                % Scaling the current atom by ||amp(p,:)||^2 before adding
                % rho*grad amounts to a step of rho/||amp(p,:)||^2.
                pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad;
                pat = pat/norm(pat);
                if is_sequential
                    res = res + (dico(:,p)-pat)*amp(p,:);
                end
                dico(:,p) = pat;
            end
        end
    case 'KSVD'
        for p = 1:nb_pattern
            index = find(amp(p,:) ~= 0);
            if ~isempty(index)
                % Residual of the signals using atom p, with that atom's
                % contribution added back.
                patch = res(:,index) + dico(:,p)*amp(p,index);
                % Only the leading left singular vector is needed, so the
                % economy-size SVD is sufficient.
                [U, S, V] = svd(full(patch), 'econ'); %#ok<ASGLU>
                atom = U(:,1);
                % Resolve the SVD sign ambiguity: keep the atom on the same
                % side as the previous one.
                if atom'*dico(:,p) <= 0
                    atom = -atom;
                end
                dico(:,p) = atom/norm(atom);
                amp(p,index) = dico(:,p)'*patch;
                if is_sequential
                    res(:,index) = patch - dico(:,p)*amp(p,index);
                end
            end
        end
    otherwise
        error('dico_update:unknownType', ...
            'Unknown dictionary update type ''%s''.', type);
end
end
107 |