diff toolboxes/FullBNT-1.0.7/netlab3.3/glmtrain.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/toolboxes/FullBNT-1.0.7/netlab3.3/glmtrain.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,206 @@
+function [net, options] = glmtrain(net, options, x, t)
+%GLMTRAIN Specialised training of generalized linear model
+%
+%	Description
+%	NET = GLMTRAIN(NET, OPTIONS, X, T) uses the iterative reweighted
+%	least squares (IRLS) algorithm to set the weights in the generalized
+%	linear model structure NET.  This is a more efficient alternative to
+%	using GLMERR and GLMGRAD and a non-linear optimisation routine
+%	through NETOPT. Note that for linear outputs, a single pass through
+%	the  algorithm is all that is required, since the error function is
+%	quadratic in the weights.  The algorithm also handles scalar ALPHA
+%	and BETA terms.  If you want to use more complicated priors, you
+%	should use general-purpose non-linear optimisation algorithms.
+%
+%	For logistic and softmax outputs, general priors can be handled,
+%	although this requires the pseudo-inverse of the Hessian, giving up
+%	the better conditioning and some of the speed advantage of the normal
+%	form equations.
+%
+%	The error function value at the final set of weights is returned in
+%	OPTIONS(8). Each row of X corresponds to one input vector and each
+%	row of T corresponds to one target vector.
+%
+%	The optional parameters have the following interpretations.
+%
+%	OPTIONS(1) is set to 1 to display error values during training. If
+%	OPTIONS(1) is set to 0, then only warning messages are displayed.  If
+%	OPTIONS(1) is -1, then nothing is displayed.
+%
+%	OPTIONS(2) is a measure of the precision required for the value of
+%	the weights W at the solution.
+%
+%	OPTIONS(3) is a measure of the precision required of the objective
+%	function at the solution.  Both this and the previous condition must
+%	be satisfied for termination.
+%
+%	OPTIONS(5) is set to 1 if an approximation to the Hessian (which
+%	assumes that all outputs are independent) is used for softmax
+%	outputs. With the default value of 0 the exact Hessian (which is more
+%	expensive to compute) is used.
+%
+%	OPTIONS(14) is the maximum number of iterations for the IRLS
+%	algorithm;  default 100.
+%
+%	See also
+%	GLM, GLMERR, GLMGRAD
+%
+
+%	Copyright (c) Ian T Nabney (1996-2001)
+
+% Check arguments for consistency
+errstring = consist(net, 'glm', x, t);
+if ~errstring
+  error(errstring);
+end
+
+if(~options(14))
+  options(14) = 100;
+end
+
+display = options(1);
+% Do we need to test for termination?
+test = (options(2) | options(3));
+
+ndata = size(x, 1);
+% Add a column of ones for the bias 
+inputs = [x ones(ndata, 1)];
+
+% Linear outputs are a special case as they can be found in one step
+if strcmp(net.outfn, 'linear')
+  if ~isfield(net, 'alpha')
+    % Solve for the weights and biases using left matrix divide
+    temp = inputs\t;
+  elseif size(net.alpha == [1 1])
+    if isfield(net, 'beta')
+      beta = net.beta;
+    else
+      beta = 1.0;
+    end
+    % Use normal form equation
+    hessian = beta*(inputs'*inputs) + net.alpha*eye(net.nin+1);
+    temp = pinv(hessian)*(beta*(inputs'*t));  
+  else
+    error('Only scalar alpha allowed');
+  end
+  net.w1 = temp(1:net.nin, :);
+  net.b1 = temp(net.nin+1, :);
+  % Store error value in options vector
+  options(8) = glmerr(net, x, t);
+  return;
+end
+
+% Otherwise need to use iterative reweighted least squares
+e = ones(1, net.nin+1);
+for n = 1:options(14)
+
+  switch net.outfn
+    case 'logistic'
+      if n == 1
+        % Initialise model
+        p = (t+0.5)/2;
+	act = log(p./(1-p));
+        wold = glmpak(net);
+      end
+      link_deriv = p.*(1-p);
+      weights = sqrt(link_deriv); % sqrt of weights
+      if (min(min(weights)) < eps)
+        warning('ill-conditioned weights in glmtrain')
+        return
+      end
+      z = act + (t-p)./link_deriv;
+      if ~isfield(net, 'alpha')
+         % Treat each output independently with relevant set of weights
+         for j = 1:net.nout
+	    indep = inputs.*(weights(:,j)*e);
+	    dep = z(:,j).*weights(:,j);
+	    temp = indep\dep;
+	    net.w1(:,j) = temp(1:net.nin);
+	    net.b1(j) = temp(net.nin+1);
+         end
+      else
+	 gradient = glmgrad(net, x, t);
+         Hessian = glmhess(net, x, t);
+         deltaw = -gradient*pinv(Hessian);
+         w = wold + deltaw;
+         net = glmunpak(net, w);
+      end
+      [err, edata, eprior, p, act] = glmerr(net, x, t);
+      if n == 1
+        errold = err;
+        wold = netpak(net);
+      else
+        w = netpak(net);
+      end
+    case 'softmax'
+      if n == 1
+        % Initialise model: ensure that row sum of p is one no matter
+	% how many classes there are
+        p = (t + (1/size(t, 2)))/2;
+	act = log(p./(1-p));
+      end
+      if options(5) == 1 | n == 1
+        link_deriv = p.*(1-p);
+        weights = sqrt(link_deriv); % sqrt of weights
+        if (min(min(weights)) < eps)
+          warning('ill-conditioned weights in glmtrain')
+          return
+        end
+        z = act + (t-p)./link_deriv;
+        % Treat each output independently with relevant set of weights
+        for j = 1:net.nout
+          indep = inputs.*(weights(:,j)*e);
+	  dep = z(:,j).*weights(:,j);
+	  temp = indep\dep;
+	  net.w1(:,j) = temp(1:net.nin);
+	  net.b1(j) = temp(net.nin+1);
+        end
+        [err, edata, eprior, p, act] = glmerr(net, x, t);
+        if n == 1
+          errold = err;
+          wold = netpak(net);
+        else
+          w = netpak(net);
+        end
+      else
+	% Exact method of calculation after w first initialised
+	% Start by working out Hessian
+	Hessian = glmhess(net, x, t);
+	gradient = glmgrad(net, x, t);
+	% Now compute modification to weights
+	deltaw = -gradient*pinv(Hessian);
+	w = wold + deltaw;
+	net = glmunpak(net, w);
+	[err, edata, eprior, p] = glmerr(net, x, t);
+    end
+
+    otherwise
+      error(['Unknown activation function ', net.outfn]);
+   end
+   if options(1)
+     fprintf(1, 'Cycle %4d Error %11.6f\n', n, err)
+   end
+   % Test for termination
+   % Terminate if error increases
+   if err >  errold
+     errold = err;
+     w = wold;
+     options(8) = err;
+     fprintf(1, 'Error has increased: terminating\n')
+     return;
+   end
+   if test & n > 1
+     if (max(abs(w - wold)) < options(2) & abs(err-errold) < options(3))
+       options(8) = err;
+       return;
+     else
+       errold = err;
+       wold = w;
+     end
+   end
+end
+
+options(8) = err;
+if (options(1) >= 0)
+  disp(maxitmess);
+end