Mercurial > hg > smallbox
comparison util/classes/dictionaryMatrices/rotatematrix.m @ 170:68fb71aa5339 danieleb
Added dictionary decorrelation functions and test script for Letters paper.
author | Daniele Barchiesi <daniele.barchiesi@eecs.qmul.ac.uk> |
---|---|
date | Thu, 06 Oct 2011 14:33:41 +0100 |
parents | 290cca7d3469 |
children | 9c41f87dead7 |
comparison
equal
deleted
inserted
replaced
169:290cca7d3469 | 170:68fb71aa5339 |
---|---|
3 % | 3 % |
4 % | 4 % |
5 % REFERENCE | 5 % REFERENCE |
6 % M.D. Plumbley, Geometrical Methods for Non-Negative ICA: Manifolds, Lie | 6 % M.D. Plumbley, Geometrical Methods for Non-Negative ICA: Manifolds, Lie |
7 % Groups and Toral Subalgebras, Neurocomputing | 7 % Groups and Toral Subalgebras, Neurocomputing |
8 | |
9 %% Parse inputs and set defaults | |
8 if ~nargin, testrotatematrix; return, end | 10 if ~nargin, testrotatematrix; return, end |
9 | 11 |
12 if ~exist('param','var') || isempty(param), param = struct; end | |
13 if ~exist('method','var') || isempty(method), method = 'conjgradLie'; end | |
14 if ~isfield(param,'nIter'), param.nIter = 100; end %number of iterations | |
15 if ~isfield(param,'eps'), param.eps = 1e-9; end %tolerance level | |
16 if ~isfield(param,'step'), param.step = 0.01; end | |
10 | 17 |
11 if ~exist('method','var') || isempty(method), method = 'unconstrained'; end | 18 J = @(W) 0.5*norm(D-W*Phi,'fro'); %cost function |
12 | 19 |
13 J = @(W) 0.5*norm(D-W*Phi,'fro'); | 20 % Initialise variables |
14 cost = zeros(param.nIter,1); | 21 cost = zeros(param.nIter,1); %cost at each iteration |
22 W = eye(size(Phi,1)); %rotation matrix | |
23 grad = ones(size(W)); %gradient | |
24 t = param.step; %step size | |
25 Gprev = 0; %previous gradient | |
26 Hprev = 0; %previous Lie search direction | |
27 iIter = 1; %iteration counter | |
15 | 28 |
16 W = eye(size(Phi,1)); | 29 %% Main algorithm |
17 t = 0; | 30 while iIter<=param.nIter && norm(grad,'fro')>eps |
18 Gprev = 0; | 31 cost(iIter) = J(W); %calculate cost |
19 Hprev = 0; | 32 grad = (W*Phi-D)*Phi'; %calculate gradient |
20 for i=1:param.nIter | |
21 cost(i) = J(W); | |
22 grad = (W*Phi-D)*Phi'; | |
23 switch method | 33 switch method |
24 case 'unconstrained' % gradient descent | 34 case 'unconstrained' % gradient descent |
25 eta = param.step; | 35 eta = param.step; |
26 W = W - eta*grad; % update W by steepest descent | 36 W = W - eta*grad; % update W by steepest descent |
27 case 'tangent' % self correcting tangent | 37 case 'tangent' % self correcting tangent |
28 eta = param.step; | 38 eta = param.step; |
29 mu = param.reg; | 39 W = W - 0.5*eta*(grad - W*grad'*W); |
30 W = W - 0.5*eta*(grad - W*grad'*W + mu*W*(W'*W-eye(size(W)))); | 40 [U , ~, V] = svd(W); |
31 case 'steepestlie' | 41 W = U*V'; |
42 case 'steepestlie' %steepest descent in Lie algebra | |
32 eta = param.step; | 43 eta = param.step; |
33 B = 2*skew(grad*W'); % calculate gradient in lie algebra | 44 B = 2*skew(grad*W'); % calculate gradient in Lie algebra |
34 W = expm(-eta*B)*W; % update W by steepest descent | 45 W = expm(-eta*B)*W; % update W by steepest descent |
35 case 'linesearchlie' | 46 case 'linesearchlie' % line search in Lie algebra |
36 B = 2*skew(grad*W'); % calculate gradient in lie algebra | 47 B = 2*skew(grad*W'); % calculate gradient in Lie algebra |
37 H = -B; % calculate direction as negative gradient | 48 H = -B; % calculate direction as negative gradient |
38 t = searchline(J,H,W,t);% line search in one-parameter lie subalgebra | 49 t = searchline(J,H,W,t);% line search in one-parameter Lie subalgebra |
39 W = expm(t*H)*W; % update W by line search | 50 W = expm(t*H)*W; % update W by line search |
40 case 'conjgradlie' | 51 case 'conjgradlie' % conjugate gradient in Lie algebra |
41 G = 2*skew(grad*W'); % calculate gradient in lie algebra | 52 G = 2*skew(grad*W'); % calculate gradient in Lie algebra |
42 H = -G + polakRibiere(G,Gprev)*Hprev; %calculate conjugate gradient direction | 53 H = -G + polakRibiere(G,Gprev)*Hprev; %calculate conjugate gradient direction |
43 t = searchline(J,H,W,t);% line search in one-parameter lie subalgebra | 54 t = searchline(J,H,W,t);% line search in one-parameter Lie subalgebra |
44 W = expm(t*H)*W; % update W by line search | 55 W = expm(t*H)*W; % update W by line search |
45 Hprev = H; % % save search direction | 56 Hprev = H; % save search direction |
46 Gprev = G; % % save gradient | 57 Gprev = G; % save gradient |
47 end | 58 end |
59 iIter = iIter+1; % update iteration counter | |
48 end | 60 end |
49 Dhat = W*Phi; | 61 Dhat = W*Phi; %rotate matrix |
62 cost(iIter:end) = cost(iIter-1); %pad cost vector | |
50 end | 63 end |
51 % function C = matcomm(A,B) | |
52 % %Matrix commutator | |
53 % C = A*B-B*A; | |
54 | 64 |
65 %% Support functions | |
55 function gamma = polakRibiere(G,Gprev) | 66 function gamma = polakRibiere(G,Gprev) |
67 %Polak-Ribiere rule for conjugate direction calculation | |
56 gamma = G(:)'*(G(:)-Gprev(:))/(norm(Gprev(:))^2); | 68 gamma = G(:)'*(G(:)-Gprev(:))/(norm(Gprev(:))^2); |
57 if isnan(gamma) || isinf(gamma) | 69 if isnan(gamma) || isinf(gamma) |
58 gamma = 0; | 70 gamma = 0; |
59 end | 71 end |
60 end | 72 end |
61 | 73 |
62 function t = searchline(J,H,W,t) | 74 function t = searchline(J,H,W,t) |
75 %Line search in one-parameter Lie subalgebra | |
63 t = fminsearch(@(x) J(expm(x*H)*W),t); | 76 t = fminsearch(@(x) J(expm(x*H)*W),t); |
64 end | 77 end |
65 | 78 |
66 function B = skew(A) | 79 function B = skew(A) |
80 %Skew-symmetric matrix | |
67 B = 0.5*(A - A'); | 81 B = 0.5*(A - A'); |
68 end | 82 end |
69 | 83 |
70 | 84 |
85 %% Test function | |
71 function testrotatematrix | 86 function testrotatematrix |
72 clear, clc, close all | 87 clear, clc, close all |
73 n = 256; | 88 n = 256; %ambient dimension |
74 m = 512; | 89 m = 512; %number of atoms |
75 disp('A random matrix...'); | 90 param.nIter = 300; %number of iterations |
76 Phi = randn(n,m); | 91 param.step = 0.001; %step size |
77 disp('And its rotated mate...'); | 92 param.mu = 0.01; %regularization factor (for tangent method) |
78 Qtrue = expm(skew(randn(n))); | 93 methods = {'unconstrained','tangent','linesearchlie','conjgradlie'}; |
79 D = Qtrue*Phi; | |
80 disp('Now, lets try to find the right rotation...'); | |
81 param.nIter = 1000; | |
82 param.step = 0.001; | |
83 | 94 |
84 cost = zeros(param.nIter,4); | 95 Phi = randn(n,m); %initial dictionary |
85 [~, cost(:,1)] = rotatematrix(D,Phi,'unconstrained',param); | 96 Qtrue = expm(skew(randn(n))); %rotation matrix |
86 [~, cost(:,2)] = rotatematrix(D,Phi,'steepestlie',param); | 97 D = Qtrue*Phi; %target dictionary |
87 [~, cost(:,3)] = rotatematrix(D,Phi,'linesearchlie',param); | 98 |
88 [~, cost(:,4)] = rotatematrix(D,Phi,'conjgradlie',param); | 99 cost = zeros(param.nIter,length(methods)); |
100 for iIter=1:length(methods) | |
101 tic | |
102 [~, cost(:,iIter)] = rotatematrix(D,Phi,methods{iIter},param); | |
103 time = toc; | |
104 sprintf('Method %s completed in %f seconds \n',methods{iIter},time) | |
105 end | |
89 | 106 |
90 figure, plot(cost) | 107 figure, plot(cost) |
91 set(gca,'XScale','log','Yscale','log') | 108 set(gca,'XScale','log','Yscale','log') |
92 legend({'uncons','settpestlie','linesearchlie','conjgradlie'}) | 109 legend(methods) |
93 grid on | 110 grid on |
94 xlabel('number of iterations') | 111 xlabel('number of iterations') |
95 ylabel('J(W)') | 112 ylabel('J(W)') |
96 end | 113 end |