comparison: DL/two-step DL/dico_update.m @ 201:5140b0e06c22 (luisf_dev)

Added separate dictionary decorrelation

author:   bmailhe
date:     Tue, 20 Mar 2012 14:50:35 +0000
parents:  485747bf39e0
children: 233e75809e4a
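For orientation before the diff: dico_update performs the dictionary-update half of a two-step dictionary learning iteration, so it is normally driven by an outer loop that alternates a sparse coding step with this routine. A minimal sketch of such a loop follows; the data sizes and the sparse coder solve_omp are hypothetical placeholders, not part of this repository:

% Sketch of a two-step dictionary learning loop around dico_update.
% solve_omp is a hypothetical sparse coder returning a sparse coefficient
% matrix with one row per atom; any such solver could be substituted.
sig  = randn(64, 1000);                            % training data, one signal per column
dico = randn(64, 128);                             % initial dictionary
dico = bsxfun(@rdivide, dico, sqrt(sum(dico.^2))); % unit-norm atoms

for it = 1:20
    amp = solve_omp(dico, sig, 5);                 % sparse coding step (placeholder)
    [dico, amp] = dico_update(dico, sig, amp, 'opt', 'sequential', 1);
end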
comparison of 200:69ce11724b1f and 201:5140b0e06c22, shown as a unified diff:
--- a/DL/two-step DL/dico_update.m	(200:69ce11724b1f)
+++ b/DL/two-step DL/dico_update.m	(201:5140b0e06c22)
@@ -1,53 +1,65 @@
 function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
 
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 % [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
 %
 % perform one iteration of dictionary update for dictionary learning
 %
 % parameters:
 % - dico: the initial dictionary with atoms as columns
 % - sig: the training data
 % - amp: the amplitude coefficients as a sparse matrix
 % - type: the algorithm can be one of the following
-%   - ols: fixed step gradient descent
-%   - mailhe: optimal step gradient descent (can be implemented as a
-%     default for ols?)
-%   - MOD: pseudo-inverse of the coefficients
-%   - KSVD: already implemented by Elad
+%   - ols: fixed step gradient descent, as described in Olshausen &
+%     Field95
+%   - opt: optimal step gradient descent, as described in Mailhe et
+%     al.08
+%   - MOD: pseudo-inverse of the coefficients, as described in Engan99
+%   - KSVD: PCA update as described in Aharon06. For fast applications,
+%     use KSVDbox rather than this code.
+%   - LGD: large step gradient descent. Equivalent to 'opt' with
+%     rho=2.
 % - flow: 'sequential' or 'parallel'. If sequential, the residual is
 %   updated after each atom update. If parallel, the residual is only
-%   updated once the whole dictionary has been computed. Sequential works
-%   better, there may be no need to implement parallel. Not used with
+%   updated once the whole dictionary has been computed.
+%   Default: Sequential (sequential usually works better). Not used with
 %   MOD.
 % - rho: learning rate. If the type is 'ols', it is the descent step of
-%   the gradient (typical choice: 0.1). If the type is 'mailhe', the
-%   descent step is the optimal step*rho (typical choice: 1, although 2
-%   or 3 seems to work better). Not used for MOD and KSVD.
+%   the gradient (default: 0.1). If the type is 'opt', the
+%   descent step is the optimal step*rho (default: 1, although 2 works
+%   better. See LGD for more details). Not used for MOD and KSVD.
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
-if ~ exist( 'rho', 'var' ) || isempty(rho)
-    rho = 0.1;
-end
 
 if ~ exist( 'flow', 'var' ) || isempty(flow)
-    flow = sequential;
+    flow = 'sequential';
 end
 
 res = sig - dico*amp;
 nb_pattern = size(dico, 2);
 
+% if the type is random, then randomly pick another type
 switch type
     case 'rand'
         x = rand();
         if x < 1/3
             type = 'MOD';
         elseif type < 2/3
-            type = 'mailhe';
+            type = 'opt';
         else
             type = 'KSVD';
         end
+end
+
+% set the learning rate to default if not provided
+if ~ exist( 'rho', 'var' ) || isempty(rho)
+    switch type
+        case 'ols'
+            rho = 0.1;
+        case 'opt'
+            rho = 1;
+    end
 end
 
 switch type
     case 'MOD'
         G = amp*amp';
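The comparison view collapses the unchanged block between the two hunks (old lines 54-69 / new lines 66-81), so the remainder of the 'MOD' case and, presumably, the 'ols' case are not reproduced here. For reference, the MOD method named in the header comment (Engan99) fits the whole dictionary at once by a least-squares solve against the coefficient Gram matrix G = amp*amp' and then renormalises the atoms; the sketch below illustrates that method under those assumptions and is not a copy of the collapsed lines:

% Illustrative MOD dictionary update (not a copy of the collapsed code).
% Solve the least-squares problem dico*amp ~= sig via the normal equations,
% then rescale every atom back to unit norm.
G    = amp*amp';
dico = (sig*amp') / (G + 1e-10*eye(size(G,1)));   % small ridge term added for numerical safety
for p = 1:size(dico, 2)
    nrm = norm(dico(:,p));
    if nrm > 0
        dico(:,p) = dico(:,p) / nrm;              % keep atoms unit-norm
    end
end

Because MOD recomputes all atoms in a single solve, there is no per-atom residual to maintain, which is why the header comment notes that flow is not used with MOD.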
@@ -70,28 +82,29 @@
                     res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU>
                 end
                 dico(:,p) = pat;
             end
         end
-    case 'mailhe'
+    case 'opt'
         for p = 1:nb_pattern
-            grad = res*amp(p,:)';
+            vec = amp(p,:);
+            grad = res*vec';
             if norm(grad) > 0
-                pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad;
+                pat = (vec*vec')*dico(:,p) + rho*grad;
                 pat = pat/norm(pat);
                 if nargin >5 && strcmp(flow, 'sequential')
-                    res = res + (dico(:,p)-pat)*amp(p,:);
+                    res = res + (dico(:,p)-pat)*vec;
                 end
                 dico(:,p) = pat;
             end
         end
     case 'KSVD'
         for p = 1:nb_pattern
             index = find(amp(p,:)~=0);
             if ~isempty(index)
                 patch = res(:,index)+dico(:,p)*amp(p,index);
-                [U,S,V] = svd(patch);
+                [U,~,V] = svd(patch);
                 if U(:,1)'*dico(:,p) > 0
                     dico(:,p) = U(:,1);
                 else
                     dico(:,p) = -U(:,1);
                 end
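The renamed 'opt' branch above updates one atom at a time: with coefficient row a = amp(p,:) and residual res, the new atom is the normalised vector (a*a')*dico(:,p) + rho*res*a', so rho = 1 gives the optimal step of Mailhe et al.08 and, per the new header comment, rho = 2 reproduces 'LGD'. A standalone sketch of that per-atom step (the function name opt_atom_update is hypothetical):

function d_new = opt_atom_update(d, a, res, rho)
% One optimal-step ('opt') atom update, mirroring the case 'opt' branch:
% push the atom along the residual/coefficient correlation, scaled by rho,
% then renormalise.
%   d   - current atom (unit-norm column)
%   a   - coefficient row amp(p,:) for this atom
%   res - current residual sig - dico*amp
grad = res * a';                     % gradient direction for this atom
if norm(grad) > 0
    d_new = (a*a') * d + rho * grad; % (a*a') is the squared norm of the coefficient row
    d_new = d_new / norm(d_new);     % keep the atom unit-norm
else
    d_new = d;                       % unused atom: leave it unchanged
end
end

With flow set to 'sequential', the caller also folds (d - d_new)*a back into the residual before moving on to the next atom, exactly as the branch above does.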