Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/netlabKPM/evidence_weighted.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [net, gamma, logev] = evidence_weighted(net, x, t, eso_w, num)
%EVIDENCE_WEIGHTED Re-estimate hyperparameters using the evidence approximation
%   with a weighted error function.
%
% Description
% NET = EVIDENCE_WEIGHTED(NET, X, T, ESO_W) re-estimates the
% hyperparameters ALPHA and (when NET has a BETA field) BETA by applying
% one cycle of the Bayesian re-estimation formulae. The hyperparameter
% ALPHA can be a simple scalar associated with an isotropic prior on the
% weights, or a vector in which each component is associated with a group
% of weights as defined by the INDEX matrix in the NET data structure.
% These more complex priors can be set up for an MLP using MLPPRIOR.
% Initial values for the iterative re-estimation are taken from NET, and
% the returned NET contains the re-estimated values.
%
% ESO_W is passed unchanged to NETERR_WEIGHTED and NETHESS_WEIGHTED; it
% presumably holds per-pattern weights for the error function
% (NOTE(review): confirm against those functions).
%
% If NET has no BETA field, the noise precision BETA is held fixed at 1
% and is not re-estimated.
%
% [NET, GAMMA, LOGEV] = EVIDENCE_WEIGHTED(NET, X, T, ESO_W, NUM) applies
% the re-estimation formulae for NUM cycles (default 1). The
% hyperparameters re-estimated in each cycle are used to re-evaluate the
% Hessian for the next cycle. GAMMA is the number of well-determined
% parameters and LOGEV is the log of the evidence, both taken from the
% final cycle.
%
% NOTE(review): LOGEV follows the older Netlab formula. It adds the error
% E returned by NETERR_WEIGHTED rather than subtracting it, and its
% determinant term uses only the eigenvalues of the data Hessian. Check
% this against the evidence framework before using LOGEV for model
% comparison.
%
% See also
% EVIDENCE, MLPPRIOR, NETGRAD, NETHESS, NETERR_WEIGHTED, NETHESS_WEIGHTED
%

% Copyright (c) Ian T Nabney (1996-9)

errstring = consist(net, '', x, t);
if ~isempty(errstring)
  error(errstring);
end

ndata = size(x, 1);
if nargin < 5 || isempty(num)
  num = 1;
end

% Networks without a BETA field use a fixed noise precision of 1. When the
% field exists, the local copy is refreshed after every re-estimation so
% that the next cycle and the log evidence use the current value.
has_beta = isfield(net, 'beta');
if has_beta
  beta = net.beta;
else
  beta = 1;
end

% Extract weights from network
pakstr = [net.type, 'pak'];
w = feval(pakstr, net);

% Evaluate data-dependent contribution to the Hessian matrix; only the
% second output (the data Hessian) is needed here.
[h, dh] = nethess_weighted(w, net, x, t, eso_w);
clear h;

% Set the negative eigenvalues of the data Hessian to zero.
[evec, evl] = eig(dh);
evl = evl.*(evl > 0);
lambda = diag(evl);
% safe_lambda is used to avoid taking log of zero
safe_lambda = lambda + eps.*(lambda <= 0);

single_alpha = (numel(net.alpha) == 1);
if ~single_alpha
  ngroups = numel(net.alpha);
  % Reconstruct the data Hessian with negative eigenvalues set to zero.
  % It does not change between cycles, so it is formed only once.
  dh = evec*evl*evec';
end

% Do the re-estimation.
for k = 1 : num
  [e, edata, eprior] = neterr_weighted(w, net, x, t, eso_w);
  if single_alpha
    % Re-estimate alpha from the number of well-determined parameters.
    B = beta*lambda;
    gamma = sum(B./(B + net.alpha));
    net.alpha = 0.5*gamma/eprior;

    % Partially evaluate log evidence
    logev = e - 0.5*sum(log(safe_lambda)) + 0.5*net.nwts*log(net.alpha) - ...
      0.5*ndata*log(2*pi);
  else
    gams = zeros(1, ngroups);
    logas = zeros(1, ngroups);
    traces = zeros(1, ngroups);
    % Inverse of the Hessian built from the reconstructed data Hessian DH
    % (presumably data plus prior contributions -- confirm in
    % NETHESS_WEIGHTED).
    hinv = inv(nethess_weighted(w, net, x, t, eso_w, dh));
    hinv_diag = diag(hinv);
    for m = 1 : ngroups
      in_group = net.index(:, m);
      group_nweights = sum(in_group);
      gams(m) = group_nweights - net.alpha(m)*sum(hinv_diag.*in_group);
      net.alpha(m) = real(gams(m)/(2*eprior(m)));
      % Weight alphas by number of weights in group
      logas(m) = 0.5*group_nweights*log(net.alpha(m));
      % Sum of log eigenvalues corresponding to the group. Indexing the
      % eigenvalue vector by membership ensures weights outside the group
      % contribute nothing.
      traces(m) = sum(log(safe_lambda(in_group ~= 0)));
    end
    gamma = sum(gams, 2);
    logev = e - 0.5*sum(traces) + sum(logas) - 0.5*ndata*log(2*pi);
  end

  % Re-estimate beta (only for networks that have one) and keep the local
  % copy current for the next cycle.
  if has_beta
    net.beta = 0.5*(net.nout*ndata - gamma)/edata;
    beta = net.beta;
  end
  logev = logev + 0.5*ndata*log(beta);
end
104 |