daniele@169
|
function [Dhat cost W] = rotatematrix(D,Phi,method,param)
% ROTATEMATRIX Find a square matrix W such that W*Phi approximates D.
%
%   [Dhat, cost, W] = ROTATEMATRIX(D, Phi, method, param) starts from
%   W = eye(size(Phi,1)) and iteratively decreases
%
%       J(W) = 0.5 * ||D - W*Phi||_F^2
%
%   D must have the same size as Phi.
%
%   method (default 'unconstrained') selects the update rule:
%     'unconstrained' - plain gradient descent, W <- W - step*grad.
%                       W is not constrained to stay orthogonal.
%     'tangent'       - self-correcting tangent update, with a term of
%                       weight param.reg penalising deviation of W'*W from I.
%     'steepestlie'   - steepest descent on the Lie group,
%                       W <- expm(-step*B)*W with B skew-symmetric, so W
%                       stays orthogonal up to rounding.
%     'linesearchlie' - like 'steepestlie', but the step along the curve
%                       expm(t*H)*W is chosen by a line search (fminsearch).
%     'conjgradlie'   - Polak-Ribiere conjugate gradient in the Lie algebra,
%                       using the same line search.
%
%   param fields:
%     nIter - number of iterations (used by every method)
%     step  - step size ('unconstrained', 'tangent', 'steepestlie')
%     reg   - penalty weight ('tangent' only)
%
%   Outputs:
%     Dhat - W*Phi for the final W
%     cost - nIter-by-1 vector; cost(i) is J(W) evaluated BEFORE the i-th
%            update, so the cost of the returned W is not included.
%     W    - the final size(Phi,1)-by-size(Phi,1) matrix
%
%   Called with no input arguments, runs the demo testrotatematrix.
%
% REFERENCE
% M.D. Plumbley, Geometrical Methods for Non-Negative ICA: Manifolds, Lie
% Groups and Toral Subalgebra, Neurocomputing

if ~nargin, testrotatematrix; return, end

if ~exist('method','var') || isempty(method), method = 'unconstrained'; end

% Fail loudly on an unknown method: otherwise the loop below would run
% without ever updating W and return the identity as if it had converged.
validMethods = {'unconstrained','tangent','steepestlie','linesearchlie','conjgradlie'};
if ~any(strcmp(method, validMethods))
    error('rotatematrix:unknownMethod', ...
        'Unknown method. Valid methods are: %s.', strjoin(validMethods, ', '));
end

% Squared Frobenius norm: (W*Phi-D)*Phi' below is exactly the gradient of
% this objective (the unsquared norm only matches it up to a scale factor).
J = @(W) 0.5*norm(D-W*Phi,'fro')^2;
cost = zeros(param.nIter,1);

W = eye(size(Phi,1));
t = 0;      % line-search step; reused as the next starting guess
Gprev = 0;  % previous Lie-algebra gradient ('conjgradlie')
Hprev = 0;  % previous search direction ('conjgradlie')
for i=1:param.nIter
    cost(i) = J(W);
    grad = (W*Phi-D)*Phi';      % Euclidean gradient of J at W
    switch method
        case 'unconstrained'    % gradient descent
            eta = param.step;
            W = W - eta*grad;   % update W by steepest descent
        case 'tangent'          % self-correcting tangent
            eta = param.step;
            mu = param.reg;
            % tangent component of the gradient plus the mu-weighted
            % term that penalises W'*W drifting away from the identity
            W = W - 0.5*eta*(grad - W*grad'*W + mu*W*(W'*W-eye(size(W))));
        case 'steepestlie'
            eta = param.step;
            B = 2*skew(grad*W');        % gradient in the Lie algebra (skew-symmetric)
            W = expm(-eta*B)*W;         % multiplicative update by an orthogonal matrix
        case 'linesearchlie'
            B = 2*skew(grad*W');        % gradient in the Lie algebra
            H = -B;                     % direction: negative gradient
            t = searchline(J,H,W,t);    % line search in one-parameter subgroup
            W = expm(t*H)*W;            % update W by line search
        case 'conjgradlie'
            G = 2*skew(grad*W');        % gradient in the Lie algebra
            H = -G + polakRibiere(G,Gprev)*Hprev; % conjugate gradient direction
            t = searchline(J,H,W,t);    % line search in one-parameter subgroup
            W = expm(t*H)*W;            % update W by line search
            Hprev = H;                  % save search direction
            Gprev = G;                  % save gradient
    end
end
Dhat = W*Phi;
end
|
daniele@169
|
% function C = matcomm(A,B)
% % Matrix commutator
% C = A*B-B*A;
|
daniele@169
|
54
|
daniele@169
|
function gamma = polakRibiere(G,Gprev)
% POLAKRIBIERE Polak-Ribiere coefficient for conjugate-gradient updates.
%   gamma = <G, G - Gprev> / ||Gprev||^2, treating both arguments as
%   column vectors. When this ratio is NaN or infinite (for example when
%   Gprev is zero, as on the first iteration) gamma is set to 0.
g = G(:);
numer = g'*(g - Gprev(:));
denom = norm(Gprev(:))^2;
gamma = numer/denom;
if ~isfinite(gamma)
    gamma = 0;
end
end
|
daniele@169
|
61
|
daniele@169
|
function t = searchline(J,H,W,t)
% SEARCHLINE Line search along the curve x -> expm(x*H)*W.
%   Returns a (local) minimiser x of J(expm(x*H)*W) as found by
%   fminsearch, started from the initial guess t.
objective = @(x) J(expm(x*H)*W);
t = fminsearch(objective, t);
end
|
daniele@169
|
65
|
daniele@169
|
function B = skew(A)
% SKEW Skew-symmetric (skew-Hermitian for complex input) part of A.
%   B = (A - A')/2, where ' is the conjugate transpose.
B = (A - A')/2;
end
|
daniele@169
|
69
|
daniele@169
|
70
|
daniele@169
|
function testrotatematrix
% TESTROTATEMATRIX Demo for rotatematrix.
%   Builds a random n-by-m matrix Phi and a rotated copy D = Qtrue*Phi,
%   where Qtrue = expm(skew(randn(n))). It then runs rotatematrix with the
%   'unconstrained', 'steepestlie', 'linesearchlie' and 'conjgradlie'
%   methods and plots their cost histories on log-log axes.
%   NOTE(review): with n = 256 and 1000 iterations the line-search methods
%   call expm inside fminsearch on every iteration, so the demo is likely
%   to be slow.
clc, close all
n = 256;
m = 512;
disp('A random matrix...');
Phi = randn(n,m);
disp('And its rotated mate...');
Qtrue = expm(skew(randn(n)));
D = Qtrue*Phi;
disp('Now, lets try to find the right rotation...');
param.nIter = 1000;
param.step = 0.001;

methods = {'unconstrained','steepestlie','linesearchlie','conjgradlie'};
cost = zeros(param.nIter,numel(methods));
for k = 1:numel(methods)
    [~, cost(:,k)] = rotatematrix(D,Phi,methods{k},param);
end

figure, plot(cost)
set(gca,'XScale','log','YScale','log')
legend(methods)
grid on
xlabel('number of iterations')
ylabel('J(W)')
end
|