function [h, hdata] = mlphess(net, x, t, hdata)
%MLPHESS Evaluate the Hessian matrix for a multi-layer perceptron network.
%
%	Description
%	H = MLPHESS(NET, X, T) takes an MLP network data structure NET, a
%	matrix X of input values, and a matrix T of target values and returns
%	the full Hessian matrix H corresponding to the second derivatives of
%	the negative log posterior distribution, evaluated for the current
%	weight and bias values as defined by NET.
%
%	[H, HDATA] = MLPHESS(NET, X, T) returns both the Hessian matrix H and
%	the contribution HDATA arising from the data dependent term in the
%	Hessian.
%
%	H = MLPHESS(NET, X, T, HDATA) takes a network data structure NET, a
%	matrix X of input values, and a matrix T of target values, together
%	with the contribution HDATA arising from the data dependent term in
%	the Hessian, and returns the full Hessian matrix H corresponding to
%	the second derivatives of the negative log posterior distribution.
%	This version saves computation time if HDATA has already been
%	evaluated for the current weight and bias values.
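%
%	Example
%	A minimal usage sketch, assuming the Netlab toolbox (MLP, CONSIST,
%	HBAYES, MLPHDOTV) is on the path. The network sizes, prior value and
%	random data below are hypothetical, chosen only for illustration. The
%	second call reuses HDATA because only the prior hyperparameter changes,
%	not the weights and biases:
%
%		net = mlp(2, 3, 1, 'linear', 0.01);  % scalar prior sets net.alpha
%		x = randn(20, 2);
%		t = randn(20, 1);
%		[h, hdata] = mlphess(net, x, t);     % computes the data term
%		net.alpha = 0.1;                     % weights unchanged
%		h2 = mlphess(net, x, t, hdata);      % skips recomputing HDATA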
%
%	See also
%	MLP, HESSCHEK, MLPHDOTV, EVIDENCE
%

%	Copyright (c) Ian T Nabney (1996-2001)

% Check arguments for consistency
errstring = consist(net, 'mlp', x, t);
if ~isempty(errstring)
  error(errstring);
end

if nargin == 3
  % Data term in Hessian needs to be computed
  hdata = datahess(net, x, t);
end

[h, hdata] = hbayes(net, hdata);

% Sub-function to compute data part of Hessian
function hdata = datahess(net, x, t)

hdata = zeros(net.nwts, net.nwts);

% Build the data Hessian one row at a time: MLPHDOTV returns the product
% of the data Hessian with V, and for the i-th unit vector V this product
% is the i-th row (the Hessian is symmetric).
for v = eye(net.nwts)
  hdata(find(v),:) = mlphdotv(net, x, t, v);
end

return