Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/netlab3.3/evidence.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [net, gamma, logev] = evidence(net, x, t, num)
%EVIDENCE Re-estimate hyperparameters using evidence approximation.
%
%   Description
%   [NET] = EVIDENCE(NET, X, T) re-estimates the hyperparameters ALPHA
%   and BETA by applying the Bayesian re-estimation formulae for a single
%   cycle. The hyperparameter ALPHA can be a simple scalar associated
%   with an isotropic prior on the weights, or can be a vector in which
%   each component is associated with a group of weights as defined by
%   the INDEX matrix in the NET data structure. These more complex priors
%   can be set up for an MLP using MLPPRIOR. Initial values for the
%   iterative re-estimation are taken from the network data structure NET
%   passed as an input argument, while the return argument NET contains
%   the re-estimated values. BETA is only re-estimated when NET has a
%   BETA field.
%
%   [NET, GAMMA, LOGEV] = EVIDENCE(NET, X, T, NUM) allows the re-
%   estimation formula to be applied for NUM cycles in which the re-
%   estimated values for the hyperparameters from each cycle are used to
%   re-evaluate the Hessian matrix for the next cycle. The return value
%   GAMMA is the number of well-determined parameters and LOGEV is the
%   log of the evidence. GAMMA and LOGEV are only assigned when NUM is at
%   least 1.
%
%   See also
%   MLPPRIOR, NETGRAD, NETHESS, DEMEV1, DEMARD
%

%   Copyright (c) Ian T Nabney (1996-2001)

errstring = consist(net, '', x, t);
if ~isempty(errstring)
  error(errstring);
end

ndata = size(x, 1);
% Default to a single re-estimation cycle when NUM is not supplied.
if nargin < 4
  num = 1;
end

% Extract weights from network
w = netpak(net);

% Evaluate data-dependent contribution to the Hessian matrix.
[h, dh] = nethess(w, net, x, t);
clear h;  % To save memory when Hessian is large
if (~isfield(net, 'beta'))
  % No noise hyperparameter in the network: use a unit scale factor in
  % the log-evidence terms below.
  local_beta = 1;
end

[evec, evl] = eig(dh);
% Now set the negative eigenvalues to zero.
evl = evl.*(evl > 0);
% safe_evl is used to avoid taking log of zero
safe_evl = evl + eps.*(evl <= 0);

% Only edata and eprior are used below; they are computed once because
% the weights do not change during re-estimation.
[e, edata, eprior] = neterr(w, net, x, t);

% The size of net.alpha never changes inside the loop (elements are only
% overwritten in place), so the scalar/grouped decision is made once.
scalar_alpha = isequal(size(net.alpha), [1 1]);

if scalar_alpha
  % Form vector of eigenvalues
  evl = diag(evl);
  safe_evl = diag(safe_evl);
else
  ngroups = size(net.alpha, 1);
  gams = zeros(1, ngroups);
  logas = zeros(1, ngroups);
  % Reconstruct data hessian with negative eigenvalues set to zero.
  dh = evec*evl*evec';
end

% Do the re-estimation.
for k = 1 : num
  % Re-estimate alpha.
  if scalar_alpha
    % Evaluate number of well-determined parameters.
    L = evl;
    if isfield(net, 'beta')
      L = net.beta*L;
    end
    gamma = sum(L./(L + net.alpha));
    net.alpha = 0.5*gamma/eprior;
    % Partially evaluate log evidence: only include unmasked weights
    logev = 0.5*length(w)*log(net.alpha);
  else
    hinv = inv(hbayes(net, dh));
    hinv_diag = diag(hinv);
    for m = 1 : ngroups
      group_nweights = sum(net.index(:, m));
      gams(m) = group_nweights - ...
          net.alpha(m)*sum(hinv_diag.*net.index(:, m));
      net.alpha(m) = real(gams(m)/(2*eprior(m)));
      % Weight alphas by number of weights in group
      logas(m) = 0.5*group_nweights*log(net.alpha(m));
    end
    gamma = sum(gams, 2);
    logev = sum(logas);
  end

  % Re-estimate beta.
  if isfield(net, 'beta')
    net.beta = 0.5*(net.nout*ndata - gamma)/edata;
    logev = logev + 0.5*ndata*log(net.beta) - 0.5*ndata*log(2*pi);
    local_beta = net.beta;
  end

  % Evaluate new log evidence
  e = errbayes(net, edata);
  if scalar_alpha
    logev = logev - e - 0.5*sum(log(local_beta*safe_evl + net.alpha));
  else
    % e is evaluated once for the whole network and does not depend on the
    % group index, so it is subtracted exactly once here (matching the
    % scalar-alpha branch) rather than once per group.
    logev = logev - e;
    for m = 1 : ngroups
      logev = logev - ...
          0.5*sum(log(local_beta*(safe_evl*net.index(:, m)) + ...
          net.alpha(m)));
    end
  end
end
114 |