annotate toolboxes/FullBNT-1.0.7/netlab3.3/glmtrain.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function [net, options] = glmtrain(net, options, x, t)
%GLMTRAIN Specialised training of generalized linear model
%
% Description
% NET = GLMTRAIN(NET, OPTIONS, X, T) uses the iterative reweighted
% least squares (IRLS) algorithm to set the weights in the generalized
% linear model structure NET. This is a more efficient alternative to
% using GLMERR and GLMGRAD and a non-linear optimisation routine
% through NETOPT. Note that for linear outputs, a single pass through
% the algorithm is all that is required, since the error function is
% quadratic in the weights. The algorithm also handles scalar ALPHA
% and BETA terms. If you want to use more complicated priors, you
% should use general-purpose non-linear optimisation algorithms.
%
% For logistic and softmax outputs, general priors can be handled,
% although this requires the pseudo-inverse of the Hessian, giving up
% the better conditioning and some of the speed advantage of the normal
% form equations.
%
% The error function value at the final set of weights is returned in
% OPTIONS(8). Each row of X corresponds to one input vector and each
% row of T corresponds to one target vector. If the error increases
% from one iteration to the next, training stops and the weights (and
% error) from the previous iteration are restored.
%
% The optional parameters have the following interpretations.
%
% OPTIONS(1) is set to 1 to display error values during training. If
% OPTIONS(1) is set to 0, then only warning messages are displayed. If
% OPTIONS(1) is -1, then nothing is displayed.
%
% OPTIONS(2) is a measure of the precision required for the value of
% the weights W at the solution.
%
% OPTIONS(3) is a measure of the precision required of the objective
% function at the solution. Both this and the previous condition must
% be satisfied for termination.
%
% OPTIONS(5) is set to 1 if an approximation to the Hessian (which
% assumes that all outputs are independent) is used for softmax
% outputs. With the default value of 0 the exact Hessian (which is more
% expensive to compute) is used.
%
% OPTIONS(14) is the maximum number of iterations for the IRLS
% algorithm; default 100.
%
% See also
% GLM, GLMERR, GLMGRAD
%

% Copyright (c) Ian T Nabney (1996-2001)

% Check arguments for consistency. CONSIST returns an empty string when
% the arguments are valid; test with ISEMPTY because ~'message' is an
% all-false array and would never trigger the error.
errstring = consist(net, 'glm', x, t);
if ~isempty(errstring)
  error(errstring);
end

if (~options(14))
  options(14) = 100;
end

% Display level: 1 = errors per cycle, 0 = warnings only, -1 = silent.
% (Named so it does not shadow the built-in DISPLAY function.)
disp_level = options(1);
% Do we need to test for termination?
test = (options(2) | options(3));

ndata = size(x, 1);
% Add a column of ones for the bias
inputs = [x ones(ndata, 1)];

% Linear outputs are a special case as they can be found in one step
if strcmp(net.outfn, 'linear')
  if ~isfield(net, 'alpha')
    % Solve for the weights and biases using left matrix divide
    temp = inputs\t;
  elseif numel(net.alpha) == 1
    if isfield(net, 'beta')
      beta = net.beta;
    else
      beta = 1.0;
    end
    % Use normal form equation; the scalar prior is applied to every
    % row of TEMP, including the bias row.
    hessian = beta*(inputs'*inputs) + net.alpha*eye(net.nin+1);
    temp = pinv(hessian)*(beta*(inputs'*t));
  else
    error('Only scalar alpha allowed');
  end
  net.w1 = temp(1:net.nin, :);
  net.b1 = temp(net.nin+1, :);
  % Store error value in options vector
  options(8) = glmerr(net, x, t);
  return;
end

% Otherwise need to use iterative reweighted least squares
for n = 1:options(14)

  switch net.outfn
    case 'logistic'
      if n == 1
        % Initialise model
        p = (t+0.5)/2;
        act = log(p./(1-p));
      end
      [weights, z, ok] = glmtrain_working(t, p, act);
      if ~ok
        warning('ill-conditioned weights in glmtrain')
        options(8) = glmerr(net, x, t);
        return
      end
      if ~isfield(net, 'alpha')
        % Treat each output independently with relevant set of weights
        net = glmtrain_irls_solve(net, inputs, weights, z);
      else
        net = glmtrain_newton_step(net, x, t);
      end

    case 'softmax'
      if n == 1
        % Initialise model: ensure that row sum of p is one no matter
        % how many classes there are
        p = (t + (1/size(t, 2)))/2;
        act = log(p./(1-p));
      end
      if options(5) == 1 || n == 1
        % Approximate (independent outputs) IRLS update; always used on
        % the first cycle to initialise the weights.
        [weights, z, ok] = glmtrain_working(t, p, act);
        if ~ok
          warning('ill-conditioned weights in glmtrain')
          options(8) = glmerr(net, x, t);
          return
        end
        net = glmtrain_irls_solve(net, inputs, weights, z);
      else
        % Exact method of calculation after w first initialised
        net = glmtrain_newton_step(net, x, t);
      end

    otherwise
      error(['Unknown activation function ', net.outfn]);
  end

  % P and ACT are refreshed here for use by the next IRLS update.
  [err, edata, eprior, p, act] = glmerr(net, x, t);
  w = glmpak(net);

  if disp_level > 0
    fprintf(1, 'Cycle %4d Error %11.6f\n', n, err)
  end

  % Test for termination
  if n > 1
    % Terminate if error increases, reverting to the previous weights
    if err > errold
      net = glmunpak(net, wold);
      options(8) = errold;
      if disp_level >= 0
        fprintf(1, 'Error has increased: terminating\n')
      end
      return;
    end
    if test && (max(abs(w - wold)) < options(2)) && (abs(err - errold) < options(3))
      options(8) = err;
      return;
    end
  end
  % Remember this iterate so the next cycle is compared against it
  errold = err;
  wold = w;
end

options(8) = err;
if (disp_level >= 0)
  disp(maxitmess);
end


function [weights, z, ok] = glmtrain_working(t, p, act)
%GLMTRAIN_WORKING IRLS weights and working response for the current fit.
% WEIGHTS is the square root of P.*(1-P). Z is the working response
% ACT + (T-P)./(P.*(1-P)). OK is false when any weight is below EPS, in
% which case Z is returned empty and is not computed.
link_deriv = p.*(1-p);
weights = sqrt(link_deriv);
ok = ~(min(min(weights)) < eps);
z = [];
if ok
  z = act + (t-p)./link_deriv;
end


function net = glmtrain_irls_solve(net, inputs, weights, z)
%GLMTRAIN_IRLS_SOLVE Weighted least-squares fit of each output unit.
% Output j is fitted independently by solving the system whose rows are
% INPUTS and Z(:,j), each scaled by WEIGHTS(:,j). The last column of
% INPUTS is the bias column; its coefficient is stored in NET.B1(j).
e = ones(1, net.nin+1);
for j = 1:net.nout
  indep = inputs.*(weights(:,j)*e);
  dep = z(:,j).*weights(:,j);
  temp = indep\dep;
  net.w1(:,j) = temp(1:net.nin);
  net.b1(j) = temp(net.nin+1);
end


function net = glmtrain_newton_step(net, x, t)
%GLMTRAIN_NEWTON_STEP One Newton step taken from the CURRENT weights.
% Uses GLMGRAD and the pseudo-inverse of GLMHESS evaluated at NET, so the
% step is added to the weights at which they were computed.
gradient = glmgrad(net, x, t);
hessian = glmhess(net, x, t);
deltaw = -gradient*pinv(hessian);
net = glmunpak(net, glmpak(net) + deltaw);