function [h, hdata] = hbayes(net, hdata)
%HBAYES Evaluate Hessian of Bayesian error function for network.
%
%	Description
%	H = HBAYES(NET, HDATA) takes a network data structure NET together
%	with the data contribution HDATA to the Hessian for a set of inputs
%	and targets. It returns the regularised Hessian using any zero mean
%	Gaussian priors on the weights defined in NET. In addition, if a MASK
%	is defined in NET, then the entries in H that correspond to weights
%	with a 0 in the mask are removed.
%
%	[H, HDATA] = HBAYES(NET, HDATA) additionally returns the data
%	component of the Hessian. When NET has a MASK field, the returned
%	HDATA is restricted to the rows and columns of the unmasked weights.
%
%	If NET has a BETA field, the data term is multiplied by NET.BETA.
%	If NET has an ALPHA field, the prior term added to H is
%	NET.ALPHA*EYE(N) when NET.ALPHA is a scalar, and otherwise
%	DIAG(NET.INDEX*NET.ALPHA). In the second case NET.ALPHA may be a row
%	or column vector with one entry per column of NET.INDEX.
%
%	See also
%	GBAYES, GLMHESS, MLPHESS, RBFHESS
%

%	Copyright (c) Ian T Nabney (1996-2001)

if isfield(net, 'mask')
  % Keep only the rows and columns of weights with a nonzero mask entry.
  % Indexing both axes with the same logical vector gives the same result
  % as selecting with the outer product MASK*MASK' and reshaping, without
  % building an extra NWTS-by-NWTS matrix. MASK(:) also accepts row masks.
  keep = logical(net.mask(:));
  hdata = hdata(keep, keep);
  nwts = nnz(keep);
else
  nwts = net.nwts;
end

% Scale the data term by the noise hyperparameter when one is present.
if isfield(net, 'beta')
  h = net.beta*hdata;
else
  h = hdata;
end

if isfield(net, 'alpha')
  if isscalar(net.alpha)
    % A single prior hyperparameter shared by every (unmasked) weight.
    h = h + net.alpha*eye(nwts);
  else
    % Several hyperparameters: NET.INDEX maps weights (rows) to
    % hyperparameters (columns). Drop the rows of masked-out weights so
    % the diagonal matches the size of H. KEEP was set above under the
    % same ISFIELD(NET, 'mask') condition.
    if isfield(net, 'mask')
      index = net.index(keep, :);
    else
      index = net.index;
    end
    % ALPHA(:) lets a row vector of hyperparameters work as well as a
    % column vector.
    h = h + diag(index*net.alpha(:));
  end
end
|