Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/netlabKPM/glmtrain_weighted.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [net, options] = glmtrain_weighted(net, options, x, t, eso_w, alfa)
%GLMTRAIN_WEIGHTED Specialised training of generalized linear model with sample weights.
%
%	Description
%	NET = GLMTRAIN_WEIGHTED(NET, OPTIONS, X, T, ESO_W, ALFA) uses a
%	weighted variant of the iterative reweighted least squares (IRLS)
%	algorithm to set the weights in the generalized linear model
%	structure NET.  Each row of X corresponds to one input vector and
%	each row of T to one target vector.  ESO_W is a column of
%	per-sample weights (one entry per row of X), and ALFA is the step
%	length applied to the Newton update in the exact-Hessian branch.
%	The error value at the final set of weights is returned in
%	OPTIONS(8).
%
%	The optional parameters have the following interpretations.
%
%	OPTIONS(1) is set to 1 to display error values during training. If
%	OPTIONS(1) is set to 0, then only warning messages are displayed.  If
%	OPTIONS(1) is -1, then nothing is displayed.
%
%	OPTIONS(2) is a measure of the precision required for the value of
%	the weights W at the solution.
%
%	OPTIONS(3) is a measure of the precision required of the objective
%	function at the solution.  Both this and the previous condition must
%	be satisfied for termination.
%
%	OPTIONS(5) is set to 1 if an approximation to the Hessian (which
%	assumes that all outputs are independent) is used for softmax
%	outputs. With the default value of 0 the exact Hessian (which is more
%	expensive to compute) is used.
%
%	OPTIONS(14) is the maximum number of iterations for the IRLS
%	algorithm; default 100.
%
%	See also
%	GLM, GLMERR, GLMGRAD
%

%	Copyright (c) Christopher M Bishop, Ian T Nabney (1996, 1997)

% Check arguments for consistency.
% BUG FIX: the original tested "if ~errstring", which negates the character
% codes of the message (all zero for a non-empty string), so consistency
% errors were silently ignored.  Test for a non-empty message instead.
errstring = consist(net, 'glm', x, t);
if ~isempty(errstring)
  error(errstring);
end

if ~options(14)
  options(14) = 100;	% Default maximum number of IRLS cycles
end

display = options(1);	% Verbosity level: 1 = per-cycle error, 0 = warnings, -1 = silent

test = (options(2) | options(3));	% Do we need to test for termination?

ndata = size(x, 1);

inputs = [x ones(ndata, 1)];	% Add a column of ones for the bias

% Use weighted iterative reweighted least squares (WIRLS)
e = ones(1, net.nin+1);
for n = 1:options(14)

  switch net.outfn
    case 'softmax'
      if n == 1
        % Initialise model: ensure that row sum of p is one no matter
        % how many classes there are.
        p = (t + (1/size(t, 2)))/2;
        act = log(p./(1-p));
      end
      if options(5) == 1 || n == 1
        % Independent-outputs (diagonal) Hessian approximation: solve a
        % separate weighted least-squares problem per output.
        link_deriv = p.*(1-p);
        weights = sqrt(link_deriv);	% sqrt of weights
        if (min(min(weights)) < eps)
          fprintf(1, 'Warning: ill-conditioned weights in glmtrain\n')
          return
        end
        z = act + (t-p)./link_deriv;	% Adjusted (working) responses
        % Treat each output independently with relevant set of weights
        for j = 1:net.nout
          indep = inputs.*(weights(:,j)*e);
          dep = z(:,j).*weights(:,j);
          temp = indep\dep;
          net.w1(:,j) = temp(1:net.nin);
          net.b1(j) = temp(net.nin+1);
        end
        [err, edata, eprior, p, act] = glmerr_weighted(net, x, t, eso_w);
        if n == 1
          errold = err;
          wold = netpak(net);
        else
          w = netpak(net);
        end
      else
        % Exact method of calculation after w first initialised.
        % Start by working out the (sample-weighted) Hessian.
        Hessian = glmhess_weighted(net, x, t, eso_w);
        % Weight each residual row by its sample weight (vectorised form
        % of the original per-row loop; identical result).
        temp = (eso_w(:,1) * ones(1, size(t, 2))) .* (p - t);
        gw1 = x'*(temp);
        gb1 = sum(temp, 1);
        gradient = [gw1(:)', gb1];
        % Now compute modification to weights (damped Newton step)
        deltaw = -gradient*pinv(Hessian);
        w = wold + alfa*deltaw;
        net = glmunpak(net, w);
        [err, edata, eprior, p] = glmerr_weighted(net, x, t, eso_w);
      end
    otherwise
      % BUG FIX: the switch dispatches on net.outfn, but the original
      % message read net.actfn, a field that does not exist for a GLM
      % struct, so reaching this branch crashed with a field-access error.
      error(['Unknown activation function ', net.outfn]);
  end % switch' end

  if display == 1
    fprintf(1, 'Cycle %4d  Error %11.6f\n', n, err)
  end
  % Test for termination
  % Terminate if error increases
  if err > errold
    errold = err;
    % NOTE(review): only the packed vector w is reverted here; net itself
    % still holds the increased-error weights on return — this matches the
    % original behaviour, so it is preserved.  Confirm against callers
    % before changing.
    w = wold;
    options(8) = err;
    fprintf(1, 'Error has increased: terminating\n')
    return;
  end
  if test && n > 1
    if (max(abs(w - wold)) < options(2) && abs(err-errold) < options(3))
      options(8) = err;
      return;
    else
      errold = err;
      wold = w;
    end
  end
end

options(8) = err;
if (display > 0)
  disp('Warning: Maximum number of iterations has been exceeded');
end