Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/netlab3.3/glmhess.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [h, hdata] = glmhess(net, x, t, hdata)
%GLMHESS Evaluate the Hessian matrix for a generalised linear model.
%
%	Description
%	H = GLMHESS(NET, X, T) takes a GLM network data structure NET, a
%	matrix X of input values, and a matrix T of target values and returns
%	the full Hessian matrix H corresponding to the second derivatives of
%	the negative log posterior distribution, evaluated for the current
%	weight and bias values as defined by NET. Note that the target data
%	is not required in the calculation, but is included to make the
%	interface uniform with NETHESS.  For linear and logistic outputs, the
%	computation is very simple and is done (in effect) in one line in
%	GLMTRAIN.
%
%	[H, HDATA] = GLMHESS(NET, X, T) returns both the Hessian matrix H and
%	the contribution HDATA arising from the data dependent term in the
%	Hessian.
%
%	H = GLMHESS(NET, X, T, HDATA) takes a network data structure NET, a
%	matrix X of input values, and a matrix T of target values, together
%	with the contribution HDATA arising from the data dependent term in
%	the Hessian, and returns the full Hessian matrix H corresponding to
%	the second derivatives of the negative log posterior distribution.
%	This version saves computation time if HDATA has already been
%	evaluated for the current weight and bias values: no forward
%	propagation and no data-term computation is performed in that case.
%
%	An error is raised if CONSIST reports inconsistent arguments, or if
%	NET.OUTFN is not one of 'linear', 'logistic' or 'softmax' when HDATA
%	has to be computed.
%
%	See also
%	GLM, GLMTRAIN, HESSCHEK, NETHESS
%

%	Copyright (c) Ian T Nabney (1996-2001)

% Check arguments for consistency
errstring = consist(net, 'glm', x, t);
if ~isempty(errstring)
  error(errstring);
end

if nargin == 3
  % HDATA was not supplied by the caller: build the data dependent term.
  ndata = size(x, 1);
  nparams = net.nwts;
  nout = net.nout;
  hdata = zeros(nparams);		% Data component of the Hessian

  % Input matrix extended by a column of ones for the bias weight.
  % Built once and shared by the linear and logistic cases.
  inputs = [x ones(ndata, 1)];

  switch net.outfn

    case 'linear'
      % No weighting function here, and no forward propagation needed:
      % every output shares the same (nin+1)-square block.
      out_hess = inputs'*inputs;
      for j = 1:nout
        hdata = rearrange_hess(net, j, out_hess, hdata);
      end

    case 'logistic'
      % Each output is independent
      p = glmfwd(net, x);
      e = ones(1, net.nin+1);
      link_deriv = p.*(1-p);
      for j = 1:nout
        % Scale each row by sqrt of the link derivative for output j, so
        % that the product below is the weighted outer-product sum.
        weighted = inputs.*(sqrt(link_deriv(:,j))*e);
        out_hess = weighted'*weighted;	% Hessian for this output
        hdata = rearrange_hess(net, j, out_hess, hdata);
      end

    case 'softmax'
      p = glmfwd(net, x);
      nwts_in = nparams - nout;		% Number of input (non-bias) weights
      bb_start = nwts_in + 1;		% Start of bias weights block
      ex_hess = zeros(nparams);	% Contribution to Hessian from single example
      for m = 1:ndata
        xm = x(m,:);
        pm = p(m,:);
        X = xm'*xm;
        a = diag(pm)-((pm')*pm);
        % Every entry of ex_hess is overwritten on each pass, so the
        % matrix can safely be reused across examples.
        ex_hess(1:nwts_in, 1:nwts_in) = kron(a, X);
        ex_hess(bb_start:nparams, bb_start:nparams) = a;
        temp = kron(a, xm);
        ex_hess(bb_start:nparams, 1:nwts_in) = temp;
        ex_hess(1:nwts_in, bb_start:nparams) = temp';
        hdata = hdata + ex_hess;
      end

    otherwise
      error(['Unknown activation function ', net.outfn]);
  end
end

% Combine the data term with NET to obtain the full Hessian.
% NOTE(review): presumably HBAYES adds the prior/regularisation
% contribution -- confirm against HBAYES.
[h, hdata] = hbayes(net, hdata);
86 | |
function hdata = rearrange_hess(net, j, out_hess, hdata)
%REARRANGE_HESS Copy the Hessian block of one output into the network Hessian.
%
%	HDATA = REARRANGE_HESS(NET, J, OUT_HESS, HDATA) writes the
%	(NET.NIN+1)-square matrix OUT_HESS, computed for output J, into the
%	matching rows and columns of HDATA and returns the updated HDATA.
%
%	In the network weight vector all the biases come after all the input
%	weights, so the weight part and the bias part of OUT_HESS land in
%	different regions of HDATA.  This function assumes that all outputs
%	are independent: only entries belonging to output J are written.

nin = net.nin;

% Rows/columns of hdata holding the input weights of output j
w_idx = ((j-1)*nin + 1):(j*nin);
% Row/column of hdata holding the bias of output j; the bias block
% occupies the last net.nout positions of the weight vector
b_idx = (net.nwts - net.nout + 1) + (j-1);

% Positions inside out_hess: input weights first, bias last
src_w = 1:nin;
src_b = nin + 1;

% Input weight block
hdata(w_idx, w_idx) = out_hess(src_w, src_w);
% Second derivative with respect to the bias weight
hdata(b_idx, b_idx) = out_hess(src_b, src_b);
% Cross terms (input weight v bias weight), row then column
hdata(b_idx, w_idx) = out_hess(src_b, src_w);
hdata(w_idx, b_idx) = out_hess(src_w, src_b);

return