comparison DL/two-step DL/dico_update.m @ 152:485747bf39e0 ivand_dev

Two-step dictionary learning: integration of the dictionary update and dictionary decorrelation code from Boris Mailhe
author Ivan Damnjanovic lnx <ivan.damnjanovic@eecs.qmul.ac.uk>
date Thu, 28 Jul 2011 15:49:32 +0100
parents
children 9eb5f0d4c1a4 5140b0e06c22
comparison
equal deleted inserted replaced
149:fec205ec6ef6 152:485747bf39e0
function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
%
% Perform one iteration of dictionary update for dictionary learning.
%
% parameters:
% - dico: the initial dictionary with atoms as columns
% - sig: the training data, one signal per column (approximated by
%   dico*amp)
% - amp: the amplitude coefficients (one row per atom), typically a
%   sparse matrix
% - type: the update algorithm (case-sensitive), one of:
%     - 'ols': fixed step gradient descent on each atom
%     - 'mailhe': optimal step gradient descent on each atom
%     - 'MOD': pseudo-inverse of the coefficients; atoms are then
%       renormalized and the rows of amp rescaled accordingly
%     - 'KSVD': rank-1 SVD update of each atom over the signals that
%       use it; the matching non-zero entries of amp are updated too
%     - 'rand': picks 'MOD', 'mailhe' or 'KSVD' with equal probability
%   Any other value raises an error.
% - flow: 'sequential' (default) or 'parallel'. If sequential, the
%   residual is updated after each atom update. If parallel, the
%   residual computed on entry is used for every atom. Sequential works
%   better. Not used with MOD.
% - rho: learning rate (default 0.1). If the type is 'ols', it is the
%   descent step of the gradient (typical choice: 0.1). If the type is
%   'mailhe', the descent step is the optimal step*rho (typical choice:
%   1, although 2 or 3 seems to work better). Not used for MOD and KSVD.
%
% outputs:
% - dico: the updated dictionary; each updated atom is normalized to
%   unit l2 norm, atoms that are not updated are returned unchanged
% - amp: the coefficients; only modified by 'MOD' and 'KSVD'
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if ~exist('rho', 'var') || isempty(rho)
    rho = 0.1;
end

if ~exist('flow', 'var') || isempty(flow)
    flow = 'sequential';
end

% Whether the residual must be refreshed after each atom update. This
% depends only on 'flow', not on which optional arguments were passed.
is_sequential = strcmp(flow, 'sequential');

res = sig - dico*amp;
nb_pattern = size(dico, 2);

switch type
    case 'rand'
        % Draw one of the three update rules with equal probability.
        x = rand();
        if x < 1/3
            type = 'MOD';
        elseif x < 2/3
            type = 'mailhe';
        else
            type = 'KSVD';
        end
end

switch type
    case 'MOD'
        % Least-squares dictionary sig*amp'*inv(amp*amp'), solved with
        % mrdivide rather than forming the inverse explicitly.
        G = amp*amp';
        dico2 = (sig*amp')/G;
        for p = 1:nb_pattern
            n = norm(dico2(:,p));
            % Renormalize the atom and scale its coefficients by the same
            % factor so the product atom*coefficients is preserved.
            % Zero-norm columns leave the previous atom in place.
            if n > 0
                dico(:,p) = dico2(:,p)/n;
                amp(p,:) = amp(p,:)*n;
            end
        end
    case 'ols'
        for p = 1:nb_pattern
            % Negative gradient of 0.5*||sig - dico*amp||_F^2 with respect
            % to atom p, i.e. the descent direction.
            grad = res*amp(p,:)';
            if norm(grad) > 0
                pat = dico(:,p) + rho*grad;
                pat = pat/norm(pat);
                if is_sequential
                    % Account for the change of atom p in the residual.
                    res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU>
                end
                dico(:,p) = pat;
            end
        end
    case 'mailhe'
        for p = 1:nb_pattern
            grad = res*amp(p,:)';
            if norm(grad) > 0
                % Up to the positive factor ||amp(p,:)||^2, which the
                % normalization removes, this is dico(:,p) plus rho times
                % the exact least-squares step grad/||amp(p,:)||^2.
                pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad;
                pat = pat/norm(pat);
                if is_sequential
                    res = res + (dico(:,p)-pat)*amp(p,:);
                end
                dico(:,p) = pat;
            end
        end
    case 'KSVD'
        for p = 1:nb_pattern
            index = find(amp(p,:) ~= 0);
            if ~isempty(index)
                % Residual restricted to the signals that use atom p, with
                % atom p's own contribution added back.
                patch = res(:,index) + dico(:,p)*amp(p,index);
                % Only the leading left singular vector is needed, so the
                % economy-size SVD avoids building the full V matrix.
                [U, ~, ~] = svd(patch, 'econ');
                % Keep the sign that is closest to the current atom.
                if U(:,1)'*dico(:,p) > 0
                    dico(:,p) = U(:,1);
                else
                    dico(:,p) = -U(:,1);
                end
                dico(:,p) = dico(:,p)/norm(dico(:,p));
                amp(p,index) = dico(:,p)'*patch;
                if is_sequential
                    res(:,index) = patch - dico(:,p)*amp(p,index);
                end
            end
        end
    otherwise
        error('dico_update:unknownType', ...
            'Unknown dictionary update type ''%s''.', type);
end
end
107