function [net, options] = glmtrain(net, options, x, t)
%GLMTRAIN Specialised training of generalized linear model
%
% Description
% NET = GLMTRAIN(NET, OPTIONS, X, T) uses the iterative reweighted
% least squares (IRLS) algorithm to set the weights in the generalized
% linear model structure NET. This is a more efficient alternative to
% using GLMERR and GLMGRAD and a non-linear optimisation routine
% through NETOPT. Note that for linear outputs, a single pass through
% the algorithm is all that is required, since the error function is
% quadratic in the weights. The algorithm also handles scalar ALPHA
% and BETA terms. If you want to use more complicated priors, you
% should use general-purpose non-linear optimisation algorithms.
%
% For logistic and softmax outputs, general priors can be handled,
% although this requires the pseudo-inverse of the Hessian, giving up
% the better conditioning and some of the speed advantage of the normal
% form equations.
%
% The error function value at the final set of weights is returned in
% OPTIONS(8). Each row of X corresponds to one input vector and each
% row of T corresponds to one target vector.
%
% If an iteration increases the error, training stops and NET is
% restored to the weights of the previous iteration; OPTIONS(8) then
% holds the error at those restored weights.
%
% The optional parameters have the following interpretations.
%
% OPTIONS(1) is set to 1 to display error values during training. If
% OPTIONS(1) is set to 0, then only warning messages are displayed. If
% OPTIONS(1) is -1, then nothing is displayed.
%
% OPTIONS(2) is a measure of the precision required for the value of
% the weights W at the solution.
%
% OPTIONS(3) is a measure of the precision required of the objective
% function at the solution. Both this and the previous condition must
% be satisfied for termination.
%
% OPTIONS(5) is set to 1 if an approximation to the Hessian (which
% assumes that all outputs are independent) is used for softmax
% outputs. With the default value of 0 the exact Hessian (which is more
% expensive to compute) is used.
%
% OPTIONS(14) is the maximum number of iterations for the IRLS
% algorithm; default 100.
%
% See also
% GLM, GLMERR, GLMGRAD
%

% Copyright (c) Ian T Nabney (1996-2001)

% Check arguments for consistency (CONSIST returns an empty string when
% the arguments are valid, otherwise an error message)
errstring = consist(net, 'glm', x, t);
if ~isempty(errstring)
  error(errstring);
end

if (~options(14))
  options(14) = 100;
end

% Verbosity level: > 0 shows errors, 0 shows warnings only, < 0 silent
verbosity = options(1);
% Do we need to test for termination?
test = (options(2) ~= 0) || (options(3) ~= 0);

ndata = size(x, 1);
% Add a column of ones for the bias
inputs = [x ones(ndata, 1)];

% Linear outputs are a special case as they can be found in one step
if strcmp(net.outfn, 'linear')
  if ~isfield(net, 'alpha')
    % Solve for the weights and biases using left matrix divide
    temp = inputs\t;
  elseif isequal(size(net.alpha), [1 1])
    if isfield(net, 'beta')
      beta = net.beta;
    else
      beta = 1.0;
    end
    % Use normal form equation
    hessian = beta*(inputs'*inputs) + net.alpha*eye(net.nin+1);
    temp = pinv(hessian)*(beta*(inputs'*t));
  else
    error('Only scalar alpha allowed');
  end
  net.w1 = temp(1:net.nin, :);
  net.b1 = temp(net.nin+1, :);
  % Store error value in options vector
  options(8) = glmerr(net, x, t);
  return;
end

% Otherwise need to use iterative reweighted least squares
e = ones(1, net.nin+1);
for n = 1:options(14)

  switch net.outfn
    case 'logistic'
      if n == 1
        % Initialise model: pull the targets away from 0 and 1 so that
        % the logit below is finite
        p = (t+0.5)/2;
        act = log(p./(1-p));
      end
      link_deriv = p.*(1-p);
      weights = sqrt(link_deriv); % sqrt of weights
      if (min(min(weights)) < eps)
        if verbosity >= 0
          warning('ill-conditioned weights in glmtrain')
        end
        % Report the error at the weights actually being returned
        options(8) = glmerr(net, x, t);
        return
      end
      if ~isfield(net, 'alpha')
        % Working response of the weighted least squares problem
        z = act + (t-p)./link_deriv;
        % Treat each output independently with relevant set of weights
        for j = 1:net.nout
          indep = inputs.*(weights(:,j)*e);
          dep = z(:,j).*weights(:,j);
          temp = indep\dep;
          net.w1(:,j) = temp(1:net.nin);
          net.b1(j) = temp(net.nin+1);
        end
      else
        % With a prior, take a Newton step from the current weights
        gradient = glmgrad(net, x, t);
        Hessian = glmhess(net, x, t);
        deltaw = -gradient*pinv(Hessian);
        net = glmunpak(net, glmpak(net) + deltaw);
      end
      [err, edata, eprior, p, act] = glmerr(net, x, t);

    case 'softmax'
      if n == 1
        % Initialise model: ensure that row sum of p is one no matter
        % how many classes there are
        p = (t + (1/size(t, 2)))/2;
        act = log(p./(1-p));
      end
      if options(5) == 1 || n == 1
        link_deriv = p.*(1-p);
        weights = sqrt(link_deriv); % sqrt of weights
        if (min(min(weights)) < eps)
          if verbosity >= 0
            warning('ill-conditioned weights in glmtrain')
          end
          % Report the error at the weights actually being returned
          options(8) = glmerr(net, x, t);
          return
        end
        % Working response of the weighted least squares problem
        z = act + (t-p)./link_deriv;
        % Treat each output independently with relevant set of weights
        for j = 1:net.nout
          indep = inputs.*(weights(:,j)*e);
          dep = z(:,j).*weights(:,j);
          temp = indep\dep;
          net.w1(:,j) = temp(1:net.nin);
          net.b1(j) = temp(net.nin+1);
        end
        [err, edata, eprior, p, act] = glmerr(net, x, t);
      else
        % Exact method of calculation after w first initialised:
        % Newton step from the current weights using the full Hessian
        Hessian = glmhess(net, x, t);
        gradient = glmgrad(net, x, t);
        deltaw = -gradient*pinv(Hessian);
        net = glmunpak(net, glmpak(net) + deltaw);
        [err, edata, eprior, p] = glmerr(net, x, t);
      end

    otherwise
      error(['Unknown activation function ', net.outfn]);
  end

  w = glmpak(net);
  if verbosity > 0
    fprintf(1, 'Cycle %4d Error %11.6f\n', n, err)
  end

  if n > 1
    % Terminate if error increases, restoring the previous weights so
    % that NET and OPTIONS(8) describe the best model found
    if err > errold
      net = glmunpak(net, wold);
      options(8) = errold;
      if verbosity >= 0
        fprintf(1, 'Error has increased: terminating\n')
      end
      return;
    end
    % Converged only when both weight and error changes are small
    if test && max(abs(w - wold)) < options(2) && abs(err - errold) < options(3)
      options(8) = err;
      return;
    end
  end
  % Remember this iteration for the next comparison
  wold = w;
  errold = err;
end

options(8) = err;
if (verbosity >= 0)
  disp(maxitmess);
end