annotate toolboxes/FullBNT-1.0.7/netlabKPM/evidence_weighted.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function [net, gamma, logev] = evidence_weighted(net, x, t, eso_w, num)
%EVIDENCE_WEIGHTED Re-estimate hyperparameters using evidence approximation.
%
%	Description
%	[NET] = EVIDENCE_WEIGHTED(NET, X, T, ESO_W) re-estimates the
%	hyperparameters ALPHA and BETA by applying Bayesian re-estimation
%	formulae for one cycle. The hyperparameter ALPHA can be a simple
%	scalar associated with an isotropic prior on the weights, or can be a
%	column vector in which each component is associated with a group of
%	weights as defined by the INDEX matrix in the NET data structure.
%	These more complex priors can be set up for an MLP using MLPPRIOR.
%	Initial values for the iterative re-estimation are taken from the
%	network data structure NET passed as an input argument, while the
%	return argument NET contains the re-estimated values. If NET has no
%	BETA field, BETA is treated as 1 and is not re-estimated.
%
%	ESO_W is forwarded unchanged to NETHESS_WEIGHTED and NETERR_WEIGHTED.
%	NOTE(review): presumably a set of per-example weights; confirm
%	against those functions.
%
%	[NET, GAMMA, LOGEV] = EVIDENCE_WEIGHTED(NET, X, T, ESO_W, NUM) applies
%	the re-estimation formulae for NUM cycles; the hyperparameter values
%	re-estimated in each cycle are used in the next cycle. The return
%	value GAMMA is the number of well-determined parameters and LOGEV is
%	the (partially evaluated) log of the evidence.
%
%	See also
%	MLPPRIOR, NETGRAD, NETHESS, DEMEV1, DEMARD
%

%	Copyright (c) Ian T Nabney (1996-9)

errstring = consist(net, '', x, t);
if ~isempty(errstring)
  error(errstring);
end

ndata = size(x, 1);

% NUM is optional and defaults to a single re-estimation cycle.
if nargin < 5
  num = 1;
end

% Networks without a BETA field are treated as having beta = 1.
if isfield(net, 'beta')
  beta = net.beta;
else
  beta = 1;
end

% Extract weights from network
pakstr = [net.type, 'pak'];
w = feval(pakstr, net);

% Evaluate data-dependent contribution to the Hessian matrix.
% Only the data term DH is needed; the full Hessian is discarded.
[h, dh] = nethess_weighted(w, net, x, t, eso_w);
clear h;

% Now set the negative eigenvalues to zero.
[evec, evl] = eig(dh);
evl = evl.*(evl > 0);
% safe_evl is used to avoid taking log of zero
safe_evl = evl + eps.*(evl <= 0);

% The shape of ALPHA does not change during re-estimation, so decide once
% which update rule applies and do the loop-invariant set-up here.
scalar_alpha = isscalar(net.alpha);
if scalar_alpha
  % Form vectors of eigenvalues
  evl = diag(evl);
  safe_evl = diag(safe_evl);
else
  ngroups = size(net.alpha, 1);
  % Reconstruct data hessian with negative eigenvalues set to zero.
  dh = evec*evl*evec';
end

% Do the re-estimation.
for k = 1 : num
  [e, edata, eprior] = neterr_weighted(w, net, x, t, eso_w);

  % Re-estimate alpha.
  if scalar_alpha
    % Evaluate number of well-determined parameters.
    B = beta*evl;
    gamma = sum(B./(B + net.alpha));
    net.alpha = 0.5*gamma/eprior;

    % Partially evaluate log evidence
    % NOTE(review): the sign of E and the use of the data-Hessian
    % eigenvalues alone (rather than those of the full Hessian) look
    % inconsistent with the usual evidence formula -- left unchanged,
    % confirm before relying on the absolute value of LOGEV.
    logev = e - 0.5*sum(log(safe_evl)) + 0.5*net.nwts*log(net.alpha) - ...
        0.5*ndata*log(2*pi);
  else
    gams = zeros(1, ngroups);
    logas = zeros(1, ngroups);
    traces = zeros(1, ngroups);
    % Inverse of the Hessian built from the clamped data term and the
    % current alpha values.
    hinv = inv(nethess_weighted(w, net, x, t, eso_w, dh));
    for m = 1 : ngroups
      group_nweights = sum(net.index(:, m));
      gams(m) = group_nweights - ...
          net.alpha(m)*sum(diag(hinv).*net.index(:, m));
      net.alpha(m) = real(gams(m)/(2*eprior(m)));
      % Weight alphas by number of weights in group
      logas(m) = 0.5*group_nweights*log(net.alpha(m));
      % Compute sum of evalues corresponding to group
      traces(m) = sum(log(safe_evl*net.index(:, m)));
    end
    gamma = sum(gams, 2);
    logev = e - 0.5*sum(traces) + sum(logas) - 0.5*ndata*log(2*pi);
  end

  % Re-estimate beta.
  if isfield(net, 'beta')
    net.beta = 0.5*(net.nout*ndata - gamma)/edata;
    % Keep the local copy in step with the re-estimated value so that the
    % log-evidence term below and the next cycle both use the updated beta.
    beta = net.beta;
  end
  logev = logev + 0.5*ndata*log(beta);
end
wolffd@0 104