toolboxes/FullBNT-1.0.7/netlab3.3/glmtrain.m

changeset:   0:e9a9cd732c1e (tip)
user:        wolffd
date:        Tue, 10 Feb 2015 15:05:51 +0000
summary:     first hg version after svn
function [net, options] = glmtrain(net, options, x, t)
%GLMTRAIN Specialised training of generalized linear model
%
% Description
% NET = GLMTRAIN(NET, OPTIONS, X, T) uses the iterative reweighted
% least squares (IRLS) algorithm to set the weights in the generalized
% linear model structure NET. This is a more efficient alternative to
% using GLMERR and GLMGRAD and a non-linear optimisation routine
% through NETOPT. Note that for linear outputs, a single pass through
% the algorithm is all that is required, since the error function is
% quadratic in the weights. The algorithm also handles scalar ALPHA
% and BETA terms. If you want to use more complicated priors, you
% should use general-purpose non-linear optimisation algorithms.
%
% For logistic and softmax outputs, general priors can be handled,
% although this requires the pseudo-inverse of the Hessian, giving up
% the better conditioning and some of the speed advantage of the normal
% form equations.
%
% The error function value at the final set of weights is returned in
% OPTIONS(8). Each row of X corresponds to one input vector and each
% row of T corresponds to one target vector.
%
% The optional parameters have the following interpretations.
%
% OPTIONS(1) is set to 1 to display error values during training. If
% OPTIONS(1) is set to 0, then only warning messages are displayed. If
% OPTIONS(1) is -1, then nothing is displayed.
%
% OPTIONS(2) is a measure of the precision required for the value of
% the weights W at the solution.
%
% OPTIONS(3) is a measure of the precision required of the objective
% function at the solution. Both this and the previous condition must
% be satisfied for termination.
%
% OPTIONS(5) is set to 1 if an approximation to the Hessian (which
% assumes that all outputs are independent) is used for softmax
% outputs. With the default value of 0 the exact Hessian (which is more
% expensive to compute) is used.
%
% OPTIONS(14) is the maximum number of iterations for the IRLS
% algorithm; default 100.
%
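% Example
% A minimal usage sketch. The data matrices X and T, and the use of
% FOPTIONS to obtain a default options vector, are assumptions for
% illustration and are not defined by this file:
%
%   net = glm(2, 1, 'logistic');  % 2-input, 1-output model, logit link
%   options = foptions;           % default options vector (assumed)
%   options(1) = 1;               % display error values while training
%   options(14) = 50;             % at most 50 IRLS iterations
%   net = glmtrain(net, options, x, t);
%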
% See also
% GLM, GLMERR, GLMGRAD
%

% Copyright (c) Ian T Nabney (1996-2001)

% Check arguments for consistency
errstring = consist(net, 'glm', x, t);
if ~isempty(errstring)
  error(errstring);
end

if(~options(14))
  options(14) = 100;
end

display = options(1);
% Do we need to test for termination?
test = (options(2) | options(3));

ndata = size(x, 1);
% Add a column of ones for the bias
inputs = [x ones(ndata, 1)];

% Linear outputs are a special case as they can be found in one step
if strcmp(net.outfn, 'linear')
  if ~isfield(net, 'alpha')
    % Solve for the weights and biases using left matrix divide
    temp = inputs\t;
  elseif size(net.alpha) == [1 1]
    if isfield(net, 'beta')
      beta = net.beta;
    else
      beta = 1.0;
    end
    % Use normal form equation
    hessian = beta*(inputs'*inputs) + net.alpha*eye(net.nin+1);
    temp = pinv(hessian)*(beta*(inputs'*t));
  else
    error('Only scalar alpha allowed');
  end
  net.w1 = temp(1:net.nin, :);
  net.b1 = temp(net.nin+1, :);
  % Store error value in options vector
  options(8) = glmerr(net, x, t);
  return;
end
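
% (For reference: the one-step solution above is the normal-form solution
% w = pinv(beta*Phi'*Phi + alpha*I)*(beta*Phi'*t), where Phi is INPUTS,
% the data matrix with the bias column appended.)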

% Otherwise need to use iterative reweighted least squares
e = ones(1, net.nin+1);
for n = 1:options(14)

  switch net.outfn
    case 'logistic'
      if n == 1
        % Initialise model
        p = (t+0.5)/2;
        act = log(p./(1-p));
        wold = glmpak(net);
      end
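      % One IRLS step: for the logit link, d(act)/dp = 1./(p.*(1-p)), so
      % the working response z computed below is a linearisation of the
      % link function about the current fit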
      link_deriv = p.*(1-p);
      weights = sqrt(link_deriv); % sqrt of weights
      if (min(min(weights)) < eps)
        warning('ill-conditioned weights in glmtrain')
        return
      end
      z = act + (t-p)./link_deriv;
      if ~isfield(net, 'alpha')
        % Treat each output independently with relevant set of weights
        for j = 1:net.nout
          indep = inputs.*(weights(:,j)*e);
          dep = z(:,j).*weights(:,j);
          temp = indep\dep;
          net.w1(:,j) = temp(1:net.nin);
          net.b1(j) = temp(net.nin+1);
        end
      else
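        % With a prior, take a full Newton step instead: w_new = w_old -
        % g*pinv(H), where g is the (row vector) error gradient from
        % GLMGRAD and H is the Hessian from GLMHESS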
        gradient = glmgrad(net, x, t);
        Hessian = glmhess(net, x, t);
        deltaw = -gradient*pinv(Hessian);
        w = wold + deltaw;
        net = glmunpak(net, w);
      end
      [err, edata, eprior, p, act] = glmerr(net, x, t);
      if n == 1
        errold = err;
        wold = netpak(net);
      else
        w = netpak(net);
      end
    case 'softmax'
      if n == 1
        % Initialise model: each row of t sums to one (1-of-N coding), so
        % each row of p sums to (1 + K*(1/K))/2 = 1 no matter how many
        % classes K = size(t, 2) there are
        p = (t + (1/size(t, 2)))/2;
        act = log(p./(1-p));
      end
      if options(5) == 1 | n == 1
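        % Independent-output approximation to the Hessian: fit each
        % output by its own weighted least-squares problem, as in the
        % logistic case. This branch is always taken on the first cycle
        % (to initialise the weights) and on every cycle when
        % OPTIONS(5) = 1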
        link_deriv = p.*(1-p);
        weights = sqrt(link_deriv); % sqrt of weights
        if (min(min(weights)) < eps)
          warning('ill-conditioned weights in glmtrain')
          return
        end
        z = act + (t-p)./link_deriv;
        % Treat each output independently with relevant set of weights
        for j = 1:net.nout
          indep = inputs.*(weights(:,j)*e);
          dep = z(:,j).*weights(:,j);
          temp = indep\dep;
          net.w1(:,j) = temp(1:net.nin);
          net.b1(j) = temp(net.nin+1);
        end
        [err, edata, eprior, p, act] = glmerr(net, x, t);
        if n == 1
          errold = err;
          wold = netpak(net);
        else
          w = netpak(net);
        end
      else
        % Exact method of calculation after w first initialised
        % Start by working out Hessian
        Hessian = glmhess(net, x, t);
        gradient = glmgrad(net, x, t);
        % Now compute modification to weights
        deltaw = -gradient*pinv(Hessian);
        w = wold + deltaw;
        net = glmunpak(net, w);
        [err, edata, eprior, p] = glmerr(net, x, t);
      end

    otherwise
      error(['Unknown activation function ', net.outfn]);
  end
  if display > 0
    fprintf(1, 'Cycle %4d Error %11.6f\n', n, err)
  end
  % Test for termination
  % Terminate if error increases
  if err > errold
    errold = err;
    w = wold;
    net = glmunpak(net, wold);  % restore the previous, lower-error weights
    options(8) = err;
    if display >= 0
      fprintf(1, 'Error has increased: terminating\n')
    end
    return;
  end
  if test & n > 1
    if (max(abs(w - wold)) < options(2) & abs(err-errold) < options(3))
      options(8) = err;
      return;
    else
      errold = err;
      wold = w;
    end
  end
end

options(8) = err;
if (options(1) >= 0)
  disp(maxitmess);
end