comparison toolboxes/FullBNT-1.0.7/netlabKPM/glmtrain_weighted.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function [net, options] = glmtrain_weighted(net, options, x, t, eso_w, alfa)
%GLMTRAIN_WEIGHTED Specialised training of generalized linear model with pattern weights
%
% Description
% [NET, OPTIONS] = GLMTRAIN_WEIGHTED(NET, OPTIONS, X, T, ESO_W, ALFA) uses
% an iterative reweighted least squares (IRLS) scheme to set the weights
% in the generalized linear model structure NET. Only softmax outputs
% (NET.OUTFN = 'softmax') are supported; any other output function raises
% an error. The error function value at the final set of weights is
% returned in OPTIONS(8). Each row of X corresponds to one input vector
% and each row of T corresponds to one target vector.
%
% ESO_W holds per-pattern weights. Row M of ESO_W(:,1) scales the
% contribution of pattern M to the gradient in the exact-Hessian update.
% ESO_W is also passed to GLMERR_WEIGHTED and GLMHESS_WEIGHTED.
% NOTE(review): presumably an NDATA-by-1 vector of non-negative weights;
% confirm against GLMERR_WEIGHTED. The IRLS least-squares fits (first
% cycle, and every cycle when OPTIONS(5) = 1) do not use ESO_W.
%
% ALFA multiplies the Newton step in the exact-Hessian update. It
% defaults to 1 (a full Newton step) when omitted or empty.
%
% The optional parameters have the following interpretations.
%
% OPTIONS(1) is set to 1 to display error values during training. If
% OPTIONS(1) is set to 0, then only warning messages are displayed. If
% OPTIONS(1) is -1, then nothing is displayed.
%
% OPTIONS(2) is a measure of the precision required for the value of
% the weights W at the solution.
%
% OPTIONS(3) is a measure of the precision required of the objective
% function at the solution. Both this and the previous condition must
% be satisfied for termination.
%
% OPTIONS(5) is set to 1 if an approximation to the Hessian (which
% assumes that all outputs are independent) is used for softmax
% outputs. With the default value of 0 the exact Hessian (which is more
% expensive to compute) is used.
%
% OPTIONS(14) is the maximum number of iterations for the IRLS
% algorithm; default 100.
%
% Training also stops, restoring the previous weights, as soon as the
% error increases from one cycle to the next.
%
% See also
% GLM, GLMERR, GLMGRAD
%

% Copyright (c) Christopher M Bishop, Ian T Nabney (1996, 1997)

% Check arguments for consistency. CONSIST returns an empty string when
% the arguments agree. "~errstring" on a non-empty message is an
% all-false array, which would silently skip the error, so test
% emptiness explicitly.
errstring = consist(net, 'glm', x, t);
if ~isempty(errstring)
  error(errstring);
end

if nargin < 6 || isempty(alfa)
  alfa = 1;   % full Newton step
end

if ~options(14)
  options(14) = 100;
end

verbosity = options(1);

% Tolerance-based termination is only tested when a tolerance is given.
test = (options(2) ~= 0) || (options(3) ~= 0);

ndata = size(x, 1);

inputs = [x ones(ndata, 1)];   % Add a column of ones for the bias

% Use weighted iterative reweighted least squares (WIRLS)
e = ones(1, net.nin+1);
for n = 1:options(14)

  switch net.outfn
    case 'softmax'
      if n == 1
        % Initialise model: ensure that row sum of p is one no matter
        % how many classes there are
        p = (t + (1/size(t, 2)))/2;
        act = log(p./(1-p));
      end
      if options(5) == 1 || n == 1
        % IRLS step: an independent weighted least-squares fit per output.
        link_deriv = p.*(1-p);
        weights = sqrt(link_deriv);   % sqrt of IRLS weights
        if min(min(weights)) < eps
          if verbosity >= 0
            fprintf(1, 'Warning: ill-conditioned weights in glmtrain\n')
          end
          if n > 1
            options(8) = err;   % error of the weights currently in NET
          end
          return
        end
        z = act + (t-p)./link_deriv;
        % Treat each output independently with relevant set of weights
        for j = 1:net.nout
          indep = inputs.*(weights(:,j)*e);
          dep = z(:,j).*weights(:,j);
          temp = indep\dep;
          net.w1(:,j) = temp(1:net.nin);
          net.b1(j) = temp(net.nin+1);
        end
        [err, edata, eprior, p, act] = glmerr_weighted(net, x, t, eso_w);
        if n == 1
          errold = err;
          wold = netpak(net);
        else
          w = netpak(net);
        end
      else
        % Exact method of calculation after w first initialised:
        % Newton step using the pattern-weighted gradient and Hessian.
        Hessian = glmhess_weighted(net, x, t, eso_w);
        temp = (p - t).*repmat(eso_w(1:ndata, 1), 1, size(p, 2));
        gw1 = x'*temp;
        gb1 = sum(temp, 1);
        grad = [gw1(:)', gb1];
        deltaw = -grad*pinv(Hessian);
        w = wold + alfa*deltaw;
        net = glmunpak(net, w);
        [err, edata, eprior, p] = glmerr_weighted(net, x, t, eso_w);
      end
    otherwise
      error(['Unknown activation function ', net.outfn]);
  end

  if verbosity == 1
    fprintf(1, 'Cycle %4d Error %11.6f\n', n, err)
  end

  % Terminate if error increases. Restore the previous weights so the
  % returned NET matches the error reported in OPTIONS(8).
  if err > errold
    net = glmunpak(net, wold);
    options(8) = errold;
    if verbosity >= 0
      fprintf(1, 'Error has increased: terminating\n')
    end
    return;
  end

  if n > 1
    if test && max(abs(w - wold)) < options(2) && abs(err - errold) < options(3)
      options(8) = err;
      return;
    end
    % Always advance the reference point: the exact-Hessian step is taken
    % from WOLD, and the error-increase check compares against ERROLD.
    errold = err;
    wold = w;
  end
end

options(8) = err;
if verbosity >= 0
  disp('Warning: Maximum number of iterations has been exceeded');
end