annotate DL/two-step DL/dico_update.m @ 207:233e75809e4a luisf_dev

Accelerated the code for LGD and optimal gradient descent
author bmailhe
date Wed, 21 Mar 2012 14:12:25 +0000
parents 5140b0e06c22
children fd0b5d36f6ad
rev   line source
function [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% [dico, amp] = dico_update(dico, sig, amp, type, flow, rho)
%
% Perform one iteration of dictionary update for dictionary learning,
% i.e. improve dico so that dico*amp better approximates sig.
%
% parameters:
% - dico: the initial dictionary with atoms as columns
% - sig: the training data (same number of rows as dico)
% - amp: the amplitude coefficients as a sparse matrix (one row per
%   atom, one column per column of sig)
% - type: the algorithm can be one of the following
%   - ols: fixed step gradient descent, as described in Olshausen &
%     Field95
%   - opt: optimal step gradient descent, as described in Mailhe et
%     al.08
%   - MOD: pseudo-inverse of the coefficients, as described in Engan99
%   - KSVD: PCA update as described in Aharon06. For fast applications,
%     use KSVDbox rather than this code.
%   - LGD: large step gradient descent. Equivalent to 'opt' with
%     rho=2 (a user-supplied rho is ignored).
%   - rand: pick one of 'MOD', 'opt' or 'KSVD' with equal probability.
% - flow: 'sequential' or 'parallel'. If sequential, the residual is
%   updated after each atom update. If parallel, the residual is only
%   updated once the whole dictionary has been computed.
%   Default: 'sequential' (sequential usually works better). Not used
%   with MOD.
% - rho: learning rate. If the type is 'ols', it is the descent step of
%   the gradient (default: 0.1). If the type is 'opt', the descent step
%   is the optimal step*rho (default: 1, although 2 works better. See
%   LGD for more details). Not used for MOD and KSVD.
%
% outputs:
% - dico: the updated dictionary. Updated atoms are normalized to unit
%   norm; atoms that cannot be updated (unused atom or zero gradient)
%   are returned unchanged.
% - amp: the coefficients. MOD rescales each row to compensate for the
%   atom normalization, KSVD re-estimates each atom's coefficients on
%   its support, the other types return amp unchanged.
%
% An unknown type raises the error 'dico_update:unknownType'.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

if ~exist('flow', 'var') || isempty(flow)
    flow = 'sequential';
end
% In sequential mode the residual is refreshed after every atom update.
% This must not depend on whether rho was supplied.
sequential = strcmpi(flow, 'sequential');

% if the type is random, then randomly pick another type
if strcmp(type, 'rand')
    x = rand();
    if x < 1/3
        type = 'MOD';
    elseif x < 2/3
        type = 'opt';
    else
        type = 'KSVD';
    end
end

% set the learning rate to default if not provided
if ~exist('rho', 'var') || isempty(rho)
    switch type
        case 'ols'
            rho = 0.1;
        case 'opt'
            rho = 1;
    end
end
% LGD is optimal-step gradient descent with a fixed multiplier of 2
if strcmp(type, 'LGD')
    rho = 2;
end

res = sig - dico*amp;
nb_pattern = size(dico, 2);

switch type
    case 'MOD'
        % Least-squares dictionary sig*pinv(amp), computed through the
        % small Gram matrix: pinv(A) = A'*pinv(A*A'). pinv (rather than
        % inv) keeps the result finite when an atom is unused, i.e. when
        % amp has an all-zero row and amp*amp' is singular.
        dico2 = (sig*amp')*pinv(full(amp*amp'));
        for p = 1:nb_pattern
            n = norm(dico2(:,p));
            % renormalize the atom and move its scale into the coefficients
            if n > 0
                dico(:,p) = dico2(:,p)/n;
                amp(p,:) = amp(p,:)*n;
            end
        end
    case 'ols'
        for p = 1:nb_pattern
            % descent direction for atom p: minus the gradient of
            % 0.5*||sig - dico*amp||_F^2 with respect to dico(:,p)
            grad = res*amp(p,:)';
            if norm(grad) > 0
                pat = dico(:,p) + rho*grad;
                pat = pat/norm(pat);
                if sequential
                    % account for the replaced atom in the residual
                    res = res + (dico(:,p)-pat)*amp(p,:); %#ok<*NASGU>
                end
                dico(:,p) = pat;
            end
        end
    case {'opt', 'LGD'}
        for p = 1:nb_pattern
            % only the signals that use atom p contribute to its gradient
            index = find(amp(p,:));
            vec = amp(p,index);
            grad = res(:,index)*vec';
            if norm(grad) > 0
                % dico + rho*grad/||vec||^2, multiplied through by
                % ||vec||^2; that scale is removed by the normalization
                pat = (vec*vec')*dico(:,p) + rho*grad;
                pat = pat/norm(pat);
                if sequential
                    res(:,index) = res(:,index) + (dico(:,p)-pat)*vec;
                end
                dico(:,p) = pat;
            end
        end
    case 'KSVD'
        for p = 1:nb_pattern
            index = find(amp(p,:));
            if ~isempty(index)
                % residual on the atom's support with its own contribution
                % added back
                patch = res(:,index) + dico(:,p)*amp(p,index);
                % only the leading left singular vector is needed, so the
                % economy SVD avoids building a full m-by-m U
                [U, ~, ~] = svd(full(patch), 'econ');
                atom = U(:,1);
                % keep the sign that agrees with the previous atom
                if atom'*dico(:,p) <= 0
                    atom = -atom;
                end
                dico(:,p) = atom/norm(atom);
                amp(p,index) = dico(:,p)'*patch;
                if sequential
                    res(:,index) = patch - dico(:,p)*amp(p,index);
                end
            end
        end
    otherwise
        error('dico_update:unknownType', 'Unknown update type ''%s''.', type);
end
end
ivan@152 135