function [net, gamma, logev] = evidence(net, x, t, num)
%EVIDENCE Re-estimate hyperparameters using evidence approximation.
%
% Description
% [NET] = EVIDENCE(NET, X, T) re-estimates the hyperparameters ALPHA
% and BETA by applying the Bayesian re-estimation formulae for a single
% iteration. The hyperparameter ALPHA can be a simple scalar associated
% with an isotropic prior on the weights, or can be a vector in which
% each component is associated with a group of weights as defined by
% the INDEX matrix in the NET data structure. These more complex
% priors can be set up for an MLP using MLPPRIOR. Initial values for
% the iterative re-estimation are taken from the network data structure
% NET passed as an input argument, while the return argument NET
% contains the re-estimated values.
%
% [NET, GAMMA, LOGEV] = EVIDENCE(NET, X, T, NUM) allows the
% re-estimation formulae to be applied for NUM cycles, in which the
% re-estimated values of the hyperparameters from each cycle are used
% to re-evaluate the Hessian matrix for the next cycle. The return
% value GAMMA is the number of well-determined parameters and LOGEV is
% the log of the evidence.
%
% See also
% MLPPRIOR, NETGRAD, NETHESS, DEMEV1, DEMARD
%

% Copyright (c) Ian T Nabney (1996-2001)

% Check that the network is consistent with the data.
errstring = consist(net, '', x, t);
if ~isempty(errstring)
  error(errstring);
end

ndata = size(x, 1);
if nargin == 3
  num = 1;   % Default to a single re-estimation cycle.
end

% Extract weights from the network.
w = netpak(net);

% Evaluate the data-dependent contribution to the Hessian matrix.
[h, dh] = nethess(w, net, x, t);
clear h;   % To save memory when the Hessian is large.
if ~isfield(net, 'beta')
  local_beta = 1;
end

[evec, evl] = eig(dh);
% Now set the negative eigenvalues to zero.
evl = evl.*(evl > 0);
% safe_evl is used to avoid taking the log of zero.
safe_evl = evl + eps.*(evl <= 0);

[e, edata, eprior] = neterr(w, net, x, t);

if size(net.alpha) == [1 1]
  % Scalar alpha (isotropic prior): form vector of eigenvalues.
  evl = diag(evl);
  safe_evl = diag(safe_evl);
else
  % Vector alpha: one hyperparameter per weight group.
  ngroups = size(net.alpha, 1);
  gams = zeros(1, ngroups);
  logas = zeros(1, ngroups);
  % Reconstruct the data Hessian with negative eigenvalues set to zero.
  dh = evec*evl*evec';
end

% Do the re-estimation.
for k = 1 : num
  % Re-estimate alpha.
  if size(net.alpha) == [1 1]
    % Evaluate the number of well-determined parameters.
    L = evl;
    if isfield(net, 'beta')
      L = net.beta*L;
    end
    gamma = sum(L./(L + net.alpha));
    net.alpha = 0.5*gamma/eprior;
    % Partially evaluate log evidence: only include unmasked weights.
    logev = 0.5*length(w)*log(net.alpha);
  else
    hinv = inv(hbayes(net, dh));
    for m = 1 : ngroups
      group_nweights = sum(net.index(:, m));
      gams(m) = group_nweights - ...
        net.alpha(m)*sum(diag(hinv).*net.index(:, m));
      net.alpha(m) = real(gams(m)/(2*eprior(m)));
      % Weight alphas by number of weights in group
      logas(m) = 0.5*group_nweights*log(net.alpha(m));
    end
    gamma = sum(gams, 2);
    logev = sum(logas);
  end
  % Re-estimate beta.
  if isfield(net, 'beta')
    net.beta = 0.5*(net.nout*ndata - gamma)/edata;
    logev = logev + 0.5*ndata*log(net.beta) - 0.5*ndata*log(2*pi);
    local_beta = net.beta;
  end

  % Evaluate new log evidence
  e = errbayes(net, edata);
  if size(net.alpha) == [1 1]
    logev = logev - e - 0.5*sum(log(local_beta*safe_evl + net.alpha));
  else
    for m = 1:ngroups
      logev = logev - e - ...
        0.5*sum(log(local_beta*(safe_evl*net.index(:, m)) + ...
        net.alpha(m)));
    end
  end
end
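
% ----------------------------------------------------------------------
% Example usage (a minimal sketch, not part of the original NETLAB file).
% It assumes the NETLAB functions MLP, NETOPT and FOPTIONS are on the
% path, and that X and T already hold the training inputs and targets;
% the settings below (NIN, NHIDDEN, NOUT, ALPHA, BETA) are illustrative
% values in the style of DEMEV1, not prescribed ones.
%
%   nin = 1; nhidden = 3; nout = 1;
%   alpha = 0.01;                 % initial (scalar, isotropic) prior hyperparameter
%   beta = 50.0;                  % initial inverse noise variance
%   net = mlp(nin, nhidden, nout, 'linear', alpha, beta);
%   options = foptions;           % default optimisation settings
%   options(14) = 100;            % maximum number of training cycles
%   net = netopt(net, options, x, t, 'scg');       % train to a (local) minimum
%   [net, gamma, logev] = evidence(net, x, t, 3);  % 3 re-estimation cycles
%
% Training with NETOPT and re-estimation with EVIDENCE are normally
% alternated several times, as in DEMEV1; grouped priors built with
% MLPPRIOR (see DEMARD) make NET.ALPHA a vector with one component per
% weight group.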