function [net, options, errlog, pointlog] = olgd(net, options, x, t)
%OLGD On-line gradient descent optimization.
%
% Description
% [NET, OPTIONS, ERRLOG, POINTLOG] = OLGD(NET, OPTIONS, X, T) uses on-
% line gradient descent to find a local minimum of the error function
% for the network NET computed on the input data X and target values T.
% A log of the error values after each cycle is (optionally) returned
% in ERRLOG, and a log of the points visited is (optionally) returned
% in POINTLOG. Because the gradient is computed on-line (i.e. after
% each pattern), this can be quite inefficient in Matlab.
%
% The error function value at the final weight vector is returned in
% OPTIONS(8).
%
% The optional parameters have the following interpretations.
%
% OPTIONS(1) is set to 1 to display error values; it also logs the error
% values in the return argument ERRLOG, and the points visited in the
% return argument POINTLOG. If OPTIONS(1) is set to 0, then only
% warning messages are displayed. If OPTIONS(1) is -1, then nothing is
% displayed.
%
% OPTIONS(2) is the precision required for the value of X at the
% solution. If the absolute difference between the values of X at
% two successive steps is less than OPTIONS(2), then this condition is
% satisfied.
%
% OPTIONS(3) is the precision required of the objective function at the
% solution. If the absolute difference between the error function values
% at two successive steps is less than OPTIONS(3), then this
% condition is satisfied. Both this and the previous condition must be
% satisfied for termination. Note that testing the function value at
% each iteration roughly halves the speed of the algorithm.
%
% OPTIONS(5) determines whether the patterns are sampled randomly with
% replacement. If it is 0 (the default), then patterns are sampled in
% order.
%
% OPTIONS(6) determines if the learning rate decays. If it is 1 then
% the learning rate decays at a rate of 1/T. If it is 0 (the default)
% then the learning rate is constant.
%
% OPTIONS(9) should be set to 1 to check the user defined gradient
% function.
%
% OPTIONS(10) returns the total number of function evaluations
% (including those in any line searches).
%
% OPTIONS(11) returns the total number of gradient evaluations.
%
% OPTIONS(14) is the maximum number of iterations (passes through the
% complete pattern set); default 100.
%
% OPTIONS(17) is the momentum; default 0.5.
%
% OPTIONS(18) is the learning rate; default 0.01.
%
% See also
% GRADDESC
%

% Copyright (c) Ian T Nabney (1996-2001)
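
% Example call (an illustrative sketch, not part of the original file):
% train a small MLP with on-line gradient descent. The MLP constructor and
% the 18-element options vector layout are standard Netlab; the data
% matrices x and t are assumed to be supplied by the user.
%
%   net = mlp(2, 3, 1, 'linear');   % 2-input, 3-hidden, 1-output network
%   options = zeros(1, 18);         % default options vector
%   options(1) = 1;                 % display and log error values
%   options(14) = 50;               % 50 passes through the data
%   options(17) = 0.5;              % momentum
%   options(18) = 0.01;             % learning rate
%   [net, options, errlog] = olgd(net, options, x, t);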

% Set up the options.
if length(options) < 18
  error('Options vector too short')
end

if (options(14))
  niters = options(14);
else
  niters = 100;
end

% Learning rate: must be positive
if (options(18) > 0)
  eta = options(18);
else
  eta = 0.01;
end
% Save initial learning rate for annealing
lr = eta;
% Momentum term: allow zero momentum
if (options(17) >= 0)
  mu = options(17);
else
  mu = 0.5;
end

pakstr = [net.type, 'pak'];
unpakstr = [net.type, 'unpak'];

% Extract initial weights from the network
w = feval(pakstr, net);

display = options(1);

% Work out if we need to compute f at each iteration.
% Needed if display results or if termination
% criterion requires it.
fcneval = (display | options(3));

% Check gradients
if (options(9))
  feval('gradchek', w, 'neterr', 'netgrad', net, x, t);
end

dwold = zeros(1, length(w));
fold = 0;  % Must be initialised so that termination test can be performed
fnew = 0;  % Initialised in case the error is never evaluated in the loop
ndata = size(x, 1);

if fcneval
  fnew = neterr(w, net, x, t);
  options(10) = options(10) + 1;
  fold = fnew;
end

j = 1;
if nargout >= 3
  errlog(j, :) = fnew;
  if nargout == 4
    pointlog(j, :) = w;
  end
end

% Main optimization loop.
while j <= niters
  wold = w;
  if options(5)
    % Randomise order of pattern presentation: with replacement
    pnum = ceil(rand(ndata, 1).*ndata);
  else
    % Present patterns in their original order
    pnum = 1:ndata;
  end
  for k = 1:ndata
    % Gradient of the error for the current pattern only
    grad = netgrad(w, net, x(pnum(k),:), t(pnum(k),:));
    if options(6)
      % Let learning rate decrease as 1/t
      lr = eta/((j-1)*ndata + k);
    end
    % Update weights using momentum and the (possibly annealed) learning rate
    dw = mu*dwold - lr*grad;
    w = w + dw;
    dwold = dw;
  end
  options(11) = options(11) + 1;  % Increment gradient evaluation count
  if fcneval
    fold = fnew;
    fnew = neterr(w, net, x, t);
    options(10) = options(10) + 1;
  end
  if display
    fprintf(1, 'Iteration %5d Error %11.8f\n', j, fnew);
  end
  j = j + 1;
  if nargout >= 3
    errlog(j) = fnew;
    if nargout == 4
      pointlog(j, :) = w;
    end
  end
  if (max(abs(w - wold)) < options(2) & abs(fnew - fold) < options(3))
    % Termination criteria are met
    options(8) = fnew;
    net = feval(unpakstr, net, w);
    return;
  end
end

if fcneval
  options(8) = fnew;
else
  % Return error on entire dataset
  options(8) = neterr(w, net, x, t);
  options(10) = options(10) + 1;
end
if (options(1) >= 0)
  disp(maxitmess);
end

net = feval(unpakstr, net, w);