annotate toolboxes/FullBNT-1.0.7/netlab3.3/demtrain.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function demtrain(action);
%DEMTRAIN Demonstrate training of MLP network.
%
%	Description
%	DEMTRAIN brings up a simple GUI to show the training of an MLP
%	network on classification and regression problems.  The user should
%	load in a dataset (which should be in Netlab format: see DATREAD),
%	select the output activation function, the number of cycles and
%	hidden units and then train the network. The scaled conjugate
%	gradient algorithm is used. A graph shows the evolution of the error:
%	the value is shown every MAX(CEIL(ITERATIONS / 50), 5) cycles.
%
%	Once the network is trained, it is saved to the file MLPTRAIN.NET.
%	The results can then be viewed as a confusion matrix (for
%	classification problems) or a plot of output versus target (for
%	regression problems).
%
%	DEMTRAIN(ACTION) is the form used by the GUI callbacks.  ACTION is
%	one of 'initialise' (the default when called with no argument),
%	'slider_moved', 'iterations_moved', 'get_ip_file', 'start', 'close',
%	'classify' or 'predict'.  Any other value is silently ignored.
%
%	See also
%	CONFMAT, DATREAD, MLP, NETOPT, SCG
%

%	Copyright (c) Ian T Nabney (1996-2001)

% If run without parameters, initialise gui.
if nargin<1,
  action='initialise';
end;

% Global variable to reference GUI figure
global DEMTRAIN_FIG
% Global array to reference sub-figures for results plots
global DEMTRAIN_RES_FIGS
global NUM_DEMTRAIN_RES_FIGS

if strcmp(action,'initialise'),

  % Create FIGURE
  fig = figure( ...
    'Name', 'Netlab Demo', ...
    'NumberTitle', 'off', ...
    'Menubar', 'none', ...
    'Color', [0.7529 0.7529 0.7529], ...
    'Visible', 'on');
  % Initialise the globals
  DEMTRAIN_FIG = fig;
  DEMTRAIN_RES_FIGS = 0;
  NUM_DEMTRAIN_RES_FIGS = 0;

  % Create GROUP for buttons
  uicontrol(fig, ...
    'Style', 'frame', ...
    'Units', 'normalized', ...
    'Position', [0.03 0.08 0.94 0.22], ...
    'BackgroundColor', [0.5 0.5 0.5]);

  % Create MAIN axis (the training error curve is drawn here)
  axes( ...
    'Units', 'normalized', ...
    'Position', [0.10 0.5 0.80 0.40], ...
    'XColor', [0 0 0], ...
    'YColor', [0 0 0], ...
    'Visible', 'on');

  % Create static text for FILENAME and PATH
  hFilename = uicontrol(fig, ...
    'Style', 'text', ...
    'Units', 'normalized', ...
    'BackgroundColor', [0.7529 0.7529 0.7529], ...
    'Position', [0.05 0.32 0.90 0.05], ...
    'HorizontalAlignment', 'center', ...
    'String', 'Please load data file.', ...
    'Visible', 'on');
  hPath = uicontrol(fig, ...
    'Style', 'text', ...
    'Units', 'normalized', ...
    'BackgroundColor', [0.7529 0.7529 0.7529], ...
    'Position', [0.05 0.37 0.90 0.05], ...
    'HorizontalAlignment', 'center', ...
    'String', '', ...
    'Visible', 'on');

  % Create NO OF HIDDEN UNITS slider and text
  hSliderText = uicontrol(fig, ...
    'Style', 'text', ...
    'BackgroundColor', [0.5 0.5 0.5], ...
    'Units', 'normalized', ...
    'Position', [0.27 0.12 0.17 0.04], ...
    'HorizontalAlignment', 'right', ...
    'String', 'Hidden Units: 5');
  hSlider = uicontrol(fig, ...
    'Style', 'slider', ...
    'Units', 'normalized', ...
    'Position', [0.45 0.12 0.26 0.04], ...
    'String', 'Slider', ...
    'Min', 1, 'Max', 25, ...
    'Value', 5, ...
    'Callback', 'demtrain slider_moved');

  % Create ITERATIONS slider and text
  hIterationsText = uicontrol(fig, ...
    'Style', 'text', ...
    'BackgroundColor', [0.5 0.5 0.5], ...
    'Units', 'normalized', ...
    'Position', [0.27 0.21 0.17 0.04], ...
    'HorizontalAlignment', 'right', ...
    'String', 'Iterations: 50');
  hIterations = uicontrol(fig, ...
    'Style', 'slider', ...
    'Units', 'normalized', ...
    'Position', [0.45 0.21 0.26 0.04], ...
    'String', 'Slider', ...
    'Min', 10, 'Max', 500, ...
    'Value', 50, ...
    'Callback', 'demtrain iterations_moved');

  % Create ACTIVATION FUNCTION popup and text
  uicontrol(fig, ...
    'Style', 'text', ...
    'BackgroundColor', [0.5 0.5 0.5], ...
    'Units', 'normalized', ...
    'Position', [0.05 0.20 0.20 0.04], ...
    'HorizontalAlignment', 'center', ...
    'String', 'Activation Function:');
  hPopup = uicontrol(fig, ...
    'Style', 'popup', ...
    'Units', 'normalized', ...
    'Position' , [0.05 0.10 0.20 0.08], ...
    'String', 'Linear|Logistic|Softmax', ...
    'Callback', '');

  % Create MENU
  hMenu1 = uimenu('Label', 'Load Data file...', 'Callback', '');
  uimenu(hMenu1, 'Label', 'Select training data file', ...
    'Callback', 'demtrain get_ip_file');
  hMenu2 = uimenu('Label', 'Show Results...', 'Callback', '');
  uimenu(hMenu2, 'Label', 'Show classification results', ...
    'Callback', 'demtrain classify');
  uimenu(hMenu2, 'Label', 'Show regression results', ...
    'Callback', 'demtrain predict');

  % Create START button (disabled until a data file has been chosen)
  hStart = uicontrol(fig, ...
    'Units', 'normalized', ...
    'Position' , [0.75 0.2 0.20 0.08], ...
    'String', 'Start Training', ...
    'Enable', 'off',...
    'Callback', 'demtrain start');

  % Create CLOSE button
  uicontrol(fig, ...
    'Units', 'normalized', ...
    'Position' , [0.75 0.1 0.20 0.08], ...
    'String', 'Close', ...
    'Callback', 'demtrain close');

  % Save handles of important UI objects.  The other actions index this
  % list by position:
  %   1 hidden-units slider   2 hidden-units label   3 file name text
  %   4 path text             5 activation popup     6 iterations slider
  %   7 iterations label      8 start button
  hndlList = [hSlider hSliderText hFilename hPath hPopup ...
      hIterations hIterationsText hStart];
  set(fig, 'UserData', hndlList);
  % Hide window from command line
  set(fig, 'HandleVisibility', 'callback');


elseif strcmp(action, 'slider_moved'),

  % Hidden-units slider has been moved.

  hndlList   = get(gcf, 'UserData');
  hSlider = hndlList(1);
  hSliderText = hndlList(2);

  val = get(hSlider, 'Value');
  % Deliberately rounds "away" from the nearest integer so that a small
  % arrow-click step still changes the value.
  if rem(val, 1) < 0.5,     % Force up and down arrows to work!
    val = ceil(val);
  else
    val = floor(val);
  end;
  set(hSlider, 'Value', val);
  set(hSliderText, 'String', ['Hidden Units: ' int2str(val)]);


elseif strcmp(action, 'iterations_moved'),

  % Iterations slider has been moved: refresh its label.

  hndlList   = get(gcf, 'UserData');
  hSlider = hndlList(6);
  hSliderText = hndlList(7);

  val = get(hSlider, 'Value');
  set(hSliderText, 'String', ['Iterations: ' int2str(val)]);

elseif strcmp(action, 'get_ip_file'),

  % Get data file menu item selected.

  hndlList = get(gcf, 'UserData');

  [datafile, datadir] = uigetfile('*.dat', 'Get Data File', 50, 50);

  % The original test also treated a numeric 0 or an empty name as "no
  % file"; isequal/isempty express that without mixing a scalar with an
  % element-wise comparison.
  if isequal(datafile, 0) | isempty(datafile),
    set(hndlList(3), 'String', 'No data file loaded.');
    set(hndlList(4), 'String', '');
    % Nothing to train on, so keep the training button unavailable
    set(hndlList(8), 'Enable', 'off');
  else
    set(hndlList(3), 'String', datafile);
    set(hndlList(4), 'String', datadir);
    % Enable training button
    set(hndlList(8), 'Enable', 'on');
  end;

elseif strcmp(action, 'start'),

  % Start training

  % Get handles of and values from UI objects
  hndlList = get(gcf, 'UserData');
  hSlider = hndlList(1);			% No of hidden units
  hIterations = hndlList(6);
  % Round so the run length matches the integer shown in the label
  % (the label is produced with int2str, which rounds).
  iterations = round(get(hIterations, 'Value'));

  hFilename = hndlList(3);			% Data file name
  hPath = hndlList(4);				% Data file path
  % Build the name once so the existence check and the load agree
  datafile = fullfile(get(hPath, 'String'), get(hFilename, 'String'));

  hPopup = hndlList(5);				% Activation function
  if get(hPopup, 'Value') == 1,
    act_fn = 'linear';
  elseif get(hPopup, 'Value') == 2,
    act_fn = 'logistic';
  else
    act_fn = 'softmax';
  end;
  nhidden = get(hSlider, 'Value');

  % Check data file exists
  if ~demtrain_readable(datafile),
    errordlg('Training data file has not been selected.', 'Error');
  else
    % Load data file
    [x,t,nin,nout,ndata] = datread(datafile);

    % Call MLPTRAIN function repeatedly, while drawing training graph.
    figure(DEMTRAIN_FIG);
    hold on;

    title('Training - please wait.');

    % Create net and find initial error
    net = mlp(size(x, 2), nhidden, size(t, 2), act_fn);
    % Initialise network with inverse variance of 10
    net = mlpinit(net, 10);
    err = mlperr(net, x, t);
    % Work out reporting step: should be sufficiently big to let training
    % algorithm have a chance
    step = max(ceil(iterations / 50), 5);

    % Refresh and rescale axis.
    % errmin/errmax (rather than min/max) keep the builtins usable here.
    cla;
    errmax = err;
    errmin = errmax/10;
    set(gca, 'YScale', 'log');
    ylabel('log Error');
    xlabel('No. iterations');
    axis([0 iterations errmin errmax+1]);
    iold = 0;
    errold = err;
    % Plot circle to show error of last iteration
    % Setting erase mode to none prevents screen flashing during
    % training
    % NOTE(review): 'EraseMode' is not available in newer MATLAB graphics
    % releases - confirm the target release before keeping or dropping it.
    plot(0, err, 'ro', 'EraseMode', 'none');
    hold on
    drawnow; % Force redraw
    for i = step-1:step:iterations,
      [net, err] = mlptrain(net, x, t, step);
      % Plot line from last point to new point.
      line([iold i], [errold err], 'Color', 'r', 'EraseMode', 'none');
      iold = i;
      errold = err;

      % If new point off scale, redraw axes.
      if err > errmax,
        errmax = err;
        axis([0 iterations errmin errmax+1]);
      end;
      if err < errmin
        errmin = err/10;
        axis([0 iterations errmin errmax+1]);
      end
      % Plot circle to show error of last iteration
      plot(i, err, 'ro', 'EraseMode', 'none');
      drawnow; % Force redraw
    end;
    save mlptrain.net net
    zoom on;

    title(['Training complete.  Final error=', num2str(err)]);

  end;

elseif strcmp(action, 'close'),

  % Close all the figures we have created
  if ishandle(DEMTRAIN_FIG)
    close(DEMTRAIN_FIG);
  end
  for n = 1:NUM_DEMTRAIN_RES_FIGS
    if ishandle(DEMTRAIN_RES_FIGS(n))
      close(DEMTRAIN_RES_FIGS(n));
    end
  end

elseif strcmp(action, 'classify'),

  % Show a confusion-matrix figure for the saved network

  if ~demtrain_readable('mlptrain.net'),
    errordlg('You have not yet trained the network.', 'Error');
  else

    hndlList = get(gcf, 'UserData');
    datafile = fullfile(get(hndlList(4), 'String'), ...
      get(hndlList(3), 'String'));
    [x,t,nin,nout,ndata] = datread(datafile);
    load mlptrain.net net -mat
    y = mlpfwd(net, x);

    % Save results figure so that it can be closed later
    NUM_DEMTRAIN_RES_FIGS = NUM_DEMTRAIN_RES_FIGS + 1;
    DEMTRAIN_RES_FIGS(NUM_DEMTRAIN_RES_FIGS)=conffig(y,t);

  end;

elseif strcmp(action, 'predict'),

  % Show one output-versus-target figure per network output

  if ~demtrain_readable('mlptrain.net'),
    errordlg('You have not yet trained the network.', 'Error');
  else

    hndlList = get(gcf, 'UserData');
    datafile = fullfile(get(hndlList(4), 'String'), ...
      get(hndlList(3), 'String'));
    [x,t,nin,nout,ndata] = datread(datafile);
    load mlptrain.net net -mat
    y = mlpfwd(net, x);

    for i = 1:size(y,2),
      % Save results figure so that it can be closed later
      NUM_DEMTRAIN_RES_FIGS = NUM_DEMTRAIN_RES_FIGS + 1;
      DEMTRAIN_RES_FIGS(NUM_DEMTRAIN_RES_FIGS) = figure;
      hold on;
      title(['Output no ' num2str(i)]);
      % Dotted diagonal from (0,0) to (1,1) as a visual reference
      plot([0 1], [0 1], 'r:');
      plot(y(:,i),t(:,i), 'o');
      hold off;
    end;
  end;

end;


function ok = demtrain_readable(fname)
%DEMTRAIN_READABLE True if FNAME can be opened for reading.
%	Opens the file with FOPEN exactly as the original inline checks did,
%	but closes the handle again so repeated GUI actions do not leak it.
fid = fopen(fname);
ok = (fid ~= -1);
if ok,
  fclose(fid);
end;