function [net, gamma, logev] = evidence(net, x, t, num)
%EVIDENCE Re-estimate hyperparameters using evidence approximation.
%
%	Description
%	[NET] = EVIDENCE(NET, X, T) re-estimates the hyperparameters ALPHA
%	and BETA by applying the Bayesian re-estimation formulae for a single
%	cycle (NUM defaults to 1). The hyperparameter ALPHA can be a simple
%	scalar associated with an isotropic prior on the weights, or can be a
%	vector in which each component is associated with a group of weights
%	as defined by the INDEX matrix in the NET data structure. These more
%	complex priors can be set up for an MLP using MLPPRIOR. Initial
%	values for the iterative re-estimation are taken from the network
%	data structure NET passed as an input argument, while the return
%	argument NET contains the re-estimated values.
%
%	[NET, GAMMA, LOGEV] = EVIDENCE(NET, X, T, NUM) allows the
%	re-estimation formula to be applied for NUM cycles in which the
%	re-estimated values for the hyperparameters from each cycle are used
%	to re-evaluate the Hessian matrix for the next cycle. The return value
%	GAMMA is the number of well-determined parameters and LOGEV is the
%	log of the evidence.
%
%	See also
%	MLPPRIOR, NETGRAD, NETHESS, DEMEV1, DEMARD
%

%	Copyright (c) Ian T Nabney (1996-2001)

errstring = consist(net, '', x, t);
if ~isempty(errstring)
  error(errstring);
end

ndata = size(x, 1);
if nargin == 3
  num = 1;
end

% Extract weights from network
w = netpak(net);

% Evaluate data-dependent contribution to the Hessian matrix.
[h, dh] = nethess(w, net, x, t);
clear h;  % To save memory when Hessian is large
if (~isfield(net, 'beta'))
  local_beta = 1;
end

[evec, evl] = eig(dh);
% Now set the negative eigenvalues to zero.
evl = evl.*(evl > 0);
% safe_evl is used to avoid taking log of zero
safe_evl = evl + eps.*(evl <= 0);

[e, edata, eprior] = neterr(w, net, x, t);

if size(net.alpha) == [1 1]
  % Form vector of eigenvalues
  evl = diag(evl);
  safe_evl = diag(safe_evl);
else
  ngroups = size(net.alpha, 1);
  gams = zeros(1, ngroups);
  logas = zeros(1, ngroups);
  % Reconstruct data hessian with negative eigenvalues set to zero.
  dh = evec*evl*evec';
end

% Do the re-estimation.
for k = 1 : num
  % Re-estimate alpha.
  if size(net.alpha) == [1 1]
    % Evaluate number of well-determined parameters.
    L = evl;
    if isfield(net, 'beta')
      L = net.beta*L;
    end
    gamma = sum(L./(L + net.alpha));
    net.alpha = 0.5*gamma/eprior;
    % Partially evaluate log evidence: only include unmasked weights
    logev = 0.5*length(w)*log(net.alpha);
  else
    hinv = inv(hbayes(net, dh));
    for m = 1 : ngroups
      group_nweights = sum(net.index(:, m));
      gams(m) = group_nweights - ...
          net.alpha(m)*sum(diag(hinv).*net.index(:,m));
      net.alpha(m) = real(gams(m)/(2*eprior(m)));
      % Weight alphas by number of weights in group
      logas(m) = 0.5*group_nweights*log(net.alpha(m));
    end
    gamma = sum(gams, 2);
    logev = sum(logas);
  end
  % Re-estimate beta.
  if isfield(net, 'beta')
    net.beta = 0.5*(net.nout*ndata - gamma)/edata;
    logev = logev + 0.5*ndata*log(net.beta) - 0.5*ndata*log(2*pi);
    local_beta = net.beta;
  end

  % Evaluate new log evidence
  e = errbayes(net, edata);
  if size(net.alpha) == [1 1]
    logev = logev - e - 0.5*sum(log(local_beta*safe_evl+net.alpha));
  else
    for m = 1:ngroups
      logev = logev - e - ...
          0.5*sum(log(local_beta*(safe_evl*net.index(:, m))+...
          net.alpha(m)));
    end
  end
end
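
% Example (illustrative sketch, not part of the original source)
% The loop below alternates weight optimisation with hyperparameter
% re-estimation, in the manner of the DEMEV1 demo listed under "See also".
% The data, network size, initial ALPHA/BETA values and cycle counts are
% assumed values chosen only for illustration. MLP and NETOPT are assumed
% to be the Netlab functions of those names, called with their usual
% argument order.
%
%   x = linspace(0, 1, 30)';                     % assumed toy inputs
%   t = sin(2*pi*x) + 0.1*randn(size(x));        % assumed noisy targets
%   net = mlp(1, 3, 1, 'linear', 0.01, 50.0);    % scalar alpha, beta present
%   options = zeros(1, 18);
%   options(14) = 300;                           % max optimiser iterations
%   for k = 1:8
%     net = netopt(net, options, x, t, 'scg');   % fit weights, alpha/beta fixed
%     [net, gamma, logev] = evidence(net, x, t, 1);
%     fprintf('alpha = %g, beta = %g, gamma = %g\n', ...
%         net.alpha, net.beta, gamma);
%   end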