Mercurial > hg > smallbox
comparison DL/two-step DL/dico_update.m @ 201:5140b0e06c22 luisf_dev
Added separate dictionary decorrelation
author | bmailhe |
---|---|
date | Tue, 20 Mar 2012 14:50:35 +0000 |
parents | 485747bf39e0 |
children | 233e75809e4a |
comparison
equal
deleted
inserted
replaced
200:69ce11724b1f | 201:5140b0e06c22 |
---|---|
1 function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho) | 1 function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho) |
2 | 2 |
3 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | 3 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% |
4 % [dico, amp] = dico_update(dico, sig, amp, type, flow, rho) | 4 % [dico, amp] = dico_update(dico, sig, amp, type, flow, rho) |
5 % | 5 % |
6 % perform one iteration of dictionary update for dictionary learning | 6 % perform one iteration of dictionary update for dictionary learning |
7 % | 7 % |
8 % parameters: | 8 % parameters: |
9 % - dico: the initial dictionary with atoms as columns | 9 % - dico: the initial dictionary with atoms as columns |
10 % - sig: the training data | 10 % - sig: the training data |
11 % - amp: the amplitude coefficients as a sparse matrix | 11 % - amp: the amplitude coefficients as a sparse matrix |
12 % - type: the algorithm can be one of the following | 12 % - type: the algorithm can be one of the following |
13 % - ols: fixed step gradient descent | 13 % - ols: fixed step gradient descent, as described in Olshausen & |
14 % - mailhe: optimal step gradient descent (can be implemented as a | 14 % Field95 |
15 % default for ols?) | 15 % - opt: optimal step gradient descent, as described in Mailhe et |
16 % - MOD: pseudo-inverse of the coefficients | 16 % al.08 |
17 % - KSVD: already implemented by Elad | 17 % - MOD: pseudo-inverse of the coefficients, as described in Engan99 |
18 % - KSVD: PCA update as described in Aharon06. For fast applications, | |
19 % use KSVDbox rather than this code. | |
20 % - LGD: large step gradient descent. Equivalent to 'opt' with | |
21 % rho=2. | |
18 % - flow: 'sequential' or 'parallel'. If sequential, the residual is | 22 % - flow: 'sequential' or 'parallel'. If sequential, the residual is |
19 % updated after each atom update. If parallel, the residual is only | 23 % updated after each atom update. If parallel, the residual is only |
20 % updated once the whole dictionary has been computed. Sequential works | 24 % updated once the whole dictionary has been computed. |
21 % better, there may be no need to implement parallel. Not used with | 25 % Default: Sequential (sequential usually works better). Not used with |
22 % MOD. | 26 % MOD. |
23 % - rho: learning rate. If the type is 'ols', it is the descent step of | 27 % - rho: learning rate. If the type is 'ols', it is the descent step of |
24 % the gradient (typical choice: 0.1). If the type is 'mailhe', the | 28 % the gradient (default: 0.1). If the type is 'opt', the |
25 % descent step is the optimal step*rho (typical choice: 1, although 2 | 29 % descent step is the optimal step*rho (default: 1, although 2 works |
26 % or 3 seems to work better). Not used for MOD and KSVD. | 30 % better. See LGD for more details). Not used for MOD and KSVD. |
27 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% | 31 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% |
28 if ~ exist( 'rho', 'var' ) || isempty(rho) | |
29 rho = 0.1; | |
30 end | |
31 | 32 |
32 if ~ exist( 'flow', 'var' ) || isempty(flow) | 33 if ~ exist( 'flow', 'var' ) || isempty(flow) |
33 flow = sequential; | 34 flow = 'sequential'; |
34 end | 35 end |
35 | 36 |
36 res = sig - dico*amp; | 37 res = sig - dico*amp; |
37 nb_pattern = size(dico, 2); | 38 nb_pattern = size(dico, 2); |
38 | 39 |
40 % if the type is random, then randomly pick another type | |
39 switch type | 41 switch type |
40 case 'rand' | 42 case 'rand' |
41 x = rand(); | 43 x = rand(); |
42 if x < 1/3 | 44 if x < 1/3 |
43 type = 'MOD'; | 45 type = 'MOD'; |
44 elseif type < 2/3 | 46 elseif type < 2/3 |
45 type = 'mailhe'; | 47 type = 'opt'; |
46 else | 48 else |
47 type = 'KSVD'; | 49 type = 'KSVD'; |
48 end | 50 end |
51 end | |
52 | |
53 % set the learning rate to default if not provided | |
54 if ~ exist( 'rho', 'var' ) || isempty(rho) | |
55 switch type | |
56 case 'ols' | |
57 rho = 0.1; | |
58 case 'opt' | |
59 rho = 1; | |
60 end | |
49 end | 61 end |
50 | 62 |
51 switch type | 63 switch type |
52 case 'MOD' | 64 case 'MOD' |
53 G = amp*amp'; | 65 G = amp*amp'; |
70 res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU> | 82 res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU> |
71 end | 83 end |
72 dico(:,p) = pat; | 84 dico(:,p) = pat; |
73 end | 85 end |
74 end | 86 end |
75 case 'mailhe' | 87 case 'opt' |
76 for p = 1:nb_pattern | 88 for p = 1:nb_pattern |
77 grad = res*amp(p,:)'; | 89 vec : amp(p,:); |
90 grad = res*vec'; | |
78 if norm(grad) > 0 | 91 if norm(grad) > 0 |
79 pat = (amp(p,:)*amp(p,:)')*dico(:,p) + rho*grad; | 92 pat = (vec*vec')*dico(:,p) + rho*grad; |
80 pat = pat/norm(pat); | 93 pat = pat/norm(pat); |
81 if nargin >5 && strcmp(flow, 'sequential') | 94 if nargin >5 && strcmp(flow, 'sequential') |
82 res = res + (dico(:,p)-pat)*amp(p,:); | 95 res = res + (dico(:,p)-pat)*vec; |
83 end | 96 end |
84 dico(:,p) = pat; | 97 dico(:,p) = pat; |
85 end | 98 end |
86 end | 99 end |
87 case 'KSVD' | 100 case 'KSVD' |
88 for p = 1:nb_pattern | 101 for p = 1:nb_pattern |
89 index = find(amp(p,:)~=0); | 102 index = find(amp(p,:)~=0); |
90 if ~isempty(index) | 103 if ~isempty(index) |
91 patch = res(:,index)+dico(:,p)*amp(p,index); | 104 patch = res(:,index)+dico(:,p)*amp(p,index); |
92 [U,S,V] = svd(patch); | 105 [U,~,V] = svd(patch); |
93 if U(:,1)'*dico(:,p) > 0 | 106 if U(:,1)'*dico(:,p) > 0 |
94 dico(:,p) = U(:,1); | 107 dico(:,p) = U(:,1); |
95 else | 108 else |
96 dico(:,p) = -U(:,1); | 109 dico(:,p) = -U(:,1); |
97 end | 110 end |