function [net, options] = glmtrain(net, options, x, t)
%GLMTRAIN Specialised training of generalized linear model
%
%	Description
%	NET = GLMTRAIN(NET, OPTIONS, X, T) uses the iterative reweighted
%	least squares (IRLS) algorithm to set the weights in the generalized
%	linear model structure NET.  This is a more efficient alternative to
%	using GLMERR and GLMGRAD and a non-linear optimisation routine
%	through NETOPT.  Note that for linear outputs, a single pass through
%	the algorithm is all that is required, since the error function is
%	quadratic in the weights.  The algorithm also handles scalar ALPHA
%	and BETA terms.  If you want to use more complicated priors, you
%	should use general-purpose non-linear optimisation algorithms.
%
%	For logistic and softmax outputs, general priors can be handled,
%	although this requires the pseudo-inverse of the Hessian, giving up
%	the better conditioning and some of the speed advantage of the normal
%	form equations.
%
%	The error function value at the final set of weights is returned in
%	OPTIONS(8).  Each row of X corresponds to one input vector and each
%	row of T corresponds to one target vector.
%
%	The optional parameters have the following interpretations.
%
%	OPTIONS(1) is set to 1 to display error values during training.  If
%	OPTIONS(1) is set to 0, then only warning messages are displayed.  If
%	OPTIONS(1) is -1, then nothing is displayed.
%
%	OPTIONS(2) is a measure of the precision required for the value of
%	the weights W at the solution.
%
%	OPTIONS(3) is a measure of the precision required of the objective
%	function at the solution.  Both this and the previous condition must
%	be satisfied for termination.
%
%	OPTIONS(5) is set to 1 if an approximation to the Hessian (which
%	assumes that all outputs are independent) is used for softmax
%	outputs.  With the default value of 0 the exact Hessian (which is
%	more expensive to compute) is used.
%
%	OPTIONS(14) is the maximum number of iterations for the IRLS
%	algorithm; default 100.
%
%	See also
%	GLM, GLMERR, GLMGRAD
%

%	Copyright (c) Ian T Nabney (1996-2001)

% Check arguments for consistency.
% BUG FIX: consist returns '' on success and a message string on failure.
% The original test `if ~errstring` never fired, because ~ applied to a
% non-empty char array yields an all-zero array, which `if` treats as
% false.  Test for a non-empty message instead.
errstring = consist(net, 'glm', x, t);
if ~isempty(errstring)
  error(errstring);
end

% Default maximum number of IRLS iterations.
if (~options(14))
  options(14) = 100;
end

display = options(1);
% Do we need to test for termination?
test = (options(2) | options(3));

ndata = size(x, 1);
% Add a column of ones for the bias term.
inputs = [x ones(ndata, 1)];

% Linear outputs are a special case as they can be found in one step,
% since the error function is quadratic in the weights.
if strcmp(net.outfn, 'linear')
  if ~isfield(net, 'alpha')
    % No prior: solve for the weights and biases by left matrix divide.
    temp = inputs\t;
  elseif prod(size(net.alpha)) == 1
    % BUG FIX: the original wrote size(net.alpha == [1 1]), which takes
    % the size of an elementwise comparison -- always non-zero -- so
    % non-scalar alpha was silently accepted here.  Test that alpha
    % itself is scalar.
    if isfield(net, 'beta')
      beta = net.beta;
    else
      beta = 1.0;
    end
    % Scalar weight-decay prior: use the normal form equations.
    hessian = beta*(inputs'*inputs) + net.alpha*eye(net.nin+1);
    temp = pinv(hessian)*(beta*(inputs'*t));
  else
    error('Only scalar alpha allowed');
  end
  net.w1 = temp(1:net.nin, :);
  net.b1 = temp(net.nin+1, :);
  % Store error value in options vector.
  options(8) = glmerr(net, x, t);
  return;
end

% Otherwise need to use iterative reweighted least squares.
e = ones(1, net.nin+1);
for n = 1:options(14)

  switch net.outfn
    case 'logistic'
      if n == 1
        % Initialise model: start the working responses from a smoothed
        % version of the targets so the link function is finite.
        p = (t+0.5)/2;
        act = log(p./(1-p));
        wold = glmpak(net);
      end
      % Derivative of the logistic link; also the IRLS weighting factor.
      link_deriv = p.*(1-p);
      weights = sqrt(link_deriv); % sqrt of weights
      if (min(min(weights)) < eps)
        warning('ill-conditioned weights in glmtrain')
        return
      end
      % Adjusted dependent variable (working response) for IRLS.
      z = act + (t-p)./link_deriv;
      if ~isfield(net, 'alpha')
        % Treat each output independently with relevant set of weights.
        for j = 1:net.nout
          indep = inputs.*(weights(:,j)*e);
          dep = z(:,j).*weights(:,j);
          temp = indep\dep;
          net.w1(:,j) = temp(1:net.nin);
          net.b1(j) = temp(net.nin+1);
        end
      else
        % With a prior, take a Newton step using the exact Hessian
        % (pseudo-inverse, as the help comment warns).
        gradient = glmgrad(net, x, t);
        Hessian = glmhess(net, x, t);
        deltaw = -gradient*pinv(Hessian);
        w = wold + deltaw;
        net = glmunpak(net, w);
      end
      [err, edata, eprior, p, act] = glmerr(net, x, t);
      if n == 1
        errold = err;
        wold = netpak(net);
      else
        w = netpak(net);
      end
    case 'softmax'
      if n == 1
        % Initialise model: ensure that row sum of p is one no matter
        % how many classes there are.
        p = (t + (1/size(t, 2)))/2;
        act = log(p./(1-p));
      end
      if options(5) == 1 | n == 1
        % Approximate (independent-outputs) Hessian; always used on the
        % first cycle to get a reasonable starting point.
        link_deriv = p.*(1-p);
        weights = sqrt(link_deriv); % sqrt of weights
        if (min(min(weights)) < eps)
          warning('ill-conditioned weights in glmtrain')
          return
        end
        z = act + (t-p)./link_deriv;
        % Treat each output independently with relevant set of weights.
        for j = 1:net.nout
          indep = inputs.*(weights(:,j)*e);
          dep = z(:,j).*weights(:,j);
          temp = indep\dep;
          net.w1(:,j) = temp(1:net.nin);
          net.b1(j) = temp(net.nin+1);
        end
        [err, edata, eprior, p, act] = glmerr(net, x, t);
        if n == 1
          errold = err;
          wold = netpak(net);
        else
          w = netpak(net);
        end
      else
        % Exact method of calculation after w first initialised.
        % Start by working out Hessian.
        Hessian = glmhess(net, x, t);
        gradient = glmgrad(net, x, t);
        % Now compute modification to weights.
        deltaw = -gradient*pinv(Hessian);
        w = wold + deltaw;
        net = glmunpak(net, w);
        [err, edata, eprior, p] = glmerr(net, x, t);
      end

    otherwise
      error(['Unknown activation function ', net.outfn]);
  end
  % BUG FIX: the original tested `if options(1)`, which is true for the
  % documented "display nothing" value -1; only display for positive
  % settings.
  if display > 0
    fprintf(1, 'Cycle %4d  Error %11.6f\n', n, err)
  end
  % Test for termination.
  % Terminate if error increases.
  % NOTE(review): w is reset to wold here but net is not re-unpacked, so
  % the returned net keeps the higher-error weights -- matches the
  % long-standing behaviour, so left unchanged; confirm intent before
  % altering.
  if err > errold
    errold = err;
    w = wold;
    options(8) = err;
    % Warning-level message: suppressed only for OPTIONS(1) = -1.
    if display >= 0
      fprintf(1, 'Error has increased: terminating\n')
    end
    return;
  end
  if test & n > 1
    if (max(abs(w - wold)) < options(2) & abs(err-errold) < options(3))
      options(8) = err;
      return;
    else
      errold = err;
      wold = w;
    end
  end
end

options(8) = err;
if (options(1) >= 0)
  disp(maxitmess);
end