function [net, gamma, logev] = evidence(net, x, t, num)
%EVIDENCE Re-estimate hyperparameters using evidence approximation.
%
%   Description
%   [NET] = EVIDENCE(NET, X, T) applies one cycle of the Bayesian
%   re-estimation formulae to the hyperparameters ALPHA and BETA stored
%   in the network data structure NET, using inputs X and targets T.
%   ALPHA may be a scalar (isotropic prior over all weights) or a vector
%   with one entry per weight group, where the groups are given by the
%   columns of the INDEX matrix in NET.  Grouped priors of this kind can
%   be created for an MLP with MLPPRIOR.  The starting values are read
%   from the NET passed in; the returned NET holds the updated values.
%   BETA is only re-estimated when NET has a BETA field.
%
%   [NET, GAMMA, LOGEV] = EVIDENCE(NET, X, T, NUM) repeats the
%   re-estimation for NUM cycles, each cycle starting from the
%   hyperparameter values produced by the previous one.  GAMMA is the
%   number of well-determined parameters and LOGEV is the log of the
%   evidence, both as computed in the final cycle.
%
%   See also
%   MLPPRIOR, NETGRAD, NETHESS, DEMEV1, DEMARD
%

%   Copyright (c) Ian T Nabney (1996-2001)

% Abort early if the network and the data do not fit together.
errstring = consist(net, '', x, t);
if ~isempty(errstring)
  error(errstring);
end

% A three-argument call means a single re-estimation cycle.
if nargin == 3
  num = 1;
end
ndata = size(x, 1);

% Packed weight vector of the network.
w = netpak(net);

% Only the data-dependent part of the Hessian (second output) is used;
% the first output is discarded straight away to save memory.
[h, dh] = nethess(w, net, x, t);
clear h;

% Whether the network carries a BETA field does not change below, so
% test it once.  Without BETA the data term is used with unit scaling.
has_beta = isfield(net, 'beta');
if ~has_beta
  local_beta = 1;
end

% Eigen-decomposition of the data Hessian.
[eigvecs, lambda] = eig(dh);
% Clip negative eigenvalues to zero.
lambda = lambda.*(lambda > 0);
% Copy with eps added wherever the entry is not positive, so that a
% logarithm can be taken later.
lambda_safe = lambda + eps.*(lambda <= 0);

% Data and prior error terms at the current weights.  The first output
% is overwritten by ERRBAYES before it is ever read.
[e, edata, eprior] = neterr(w, net, x, t);

if all(size(net.alpha) == [1 1])
  % Single alpha: only the eigenvalues themselves are needed.
  lambda = diag(lambda);
  lambda_safe = diag(lambda_safe);
else
  % One alpha per weight group.
  ngroups = size(net.alpha, 1);
  group_gamma = zeros(1, ngroups);
  group_logalpha = zeros(1, ngroups);
  % Rebuild the data Hessian from the clipped eigenvalues.
  dh = eigvecs*lambda*eigvecs';
end

% Re-estimation cycles.
for cycle = 1 : num

  % ---- alpha ----
  if all(size(net.alpha) == [1 1])
    % Number of well-determined parameters from the (beta-scaled)
    % eigenvalues of the data Hessian.
    scaled = lambda;
    if has_beta
      scaled = net.beta*scaled;
    end
    gamma = sum(scaled./(scaled + net.alpha));
    net.alpha = 0.5*gamma/eprior;
    % Alpha contribution to the log evidence, weighted by the number of
    % packed weights.
    logev = 0.5*length(w)*log(net.alpha);
  else
    % Diagonal of the inverse of the Hessian returned by HBAYES.
    hinv_diag = diag(inv(hbayes(net, dh)));
    for g = 1 : ngroups
      members = net.index(:, g);
      nw = sum(members);
      group_gamma(g) = nw - ...
        net.alpha(g)*sum(hinv_diag.*net.index(:,g));
      net.alpha(g) = real(group_gamma(g)/(2*eprior(g)));
      % Each group's log(alpha) is weighted by its number of weights.
      group_logalpha(g) = 0.5*nw*log(net.alpha(g));
    end
    gamma = sum(group_gamma, 2);
    logev = sum(group_logalpha);
  end

  % ---- beta ----
  if has_beta
    net.beta = 0.5*(net.nout*ndata - gamma)/edata;
    logev = logev + 0.5*ndata*log(net.beta) - 0.5*ndata*log(2*pi);
    local_beta = net.beta;
  end

  % ---- remaining log-evidence terms, using the updated hyperparameters ----
  e = errbayes(net, edata);
  if all(size(net.alpha) == [1 1])
    logev = logev - e - 0.5*sum(log(local_beta*lambda_safe+net.alpha));
  else
    % NOTE(review): E is subtracted once per group inside this loop and
    % LAMBDA_SAFE is still a full matrix here; kept exactly as in the
    % original -- confirm against the intended evidence formula.
    for g = 1 : ngroups
      logev = logev - e - ...
        0.5*sum(log(local_beta*(lambda_safe*net.index(:, g))+...
        net.alpha(g)));
    end
  end
end