annotate toolboxes/FullBNT-1.0.7/netlab3.3/demtrain.m @ 0:cc4b1211e677 tip

initial commit to HG from Changeset: 646 (e263d8a21543) added further path and more save "camirversion.m"
author Daniel Wolff
date Fri, 19 Aug 2016 13:07:06 +0200
parents
children
rev   line source
function demtrain(action);
%DEMTRAIN Demonstrate training of MLP network.
%
%	Description
%	DEMTRAIN brings up a simple GUI to show the training of an MLP
%	network on classification and regression problems.  The user should
%	load in a dataset (which should be in Netlab format: see DATREAD),
%	select the output activation function, the number of cycles and
%	hidden units and then train the network. The scaled conjugate
%	gradient algorithm is used. A graph shows the evolution of the error:
%	the value is shown every MAX(CEIL(ITERATIONS / 50), 5) cycles.
%
%	Once the network is trained, it is saved to the file MLPTRAIN.NET.
%	The results can then be viewed as a confusion matrix (for
%	classification problems) or a plot of output versus target (for
%	regression problems).
%
%	DEMTRAIN(ACTION) is the form used by the GUI callbacks.  ACTION is
%	one of 'initialise' (the default when called with no argument),
%	'slider_moved', 'iterations_moved', 'get_ip_file', 'start', 'close',
%	'classify' or 'predict'.  Any other string is silently ignored.
%
%	See also
%	CONFMAT, DATREAD, MLP, NETOPT, SCG
%

%	Copyright (c) Ian T Nabney (1996-2001)

% If run without parameters, initialise gui.
if nargin < 1,
  action = 'initialise';
end;

% Global variable to reference GUI figure
global DEMTRAIN_FIG
% Global array to reference sub-figures for results plots
global DEMTRAIN_RES_FIGS
global NUM_DEMTRAIN_RES_FIGS

if strcmp(action, 'initialise'),

  % Create FIGURE
  fig = figure( ...
    'Name', 'Netlab Demo', ...
    'NumberTitle', 'off', ...
    'Menubar', 'none', ...
    'Color', [0.7529 0.7529 0.7529], ...
    'Visible', 'on');
  % Initialise the globals
  DEMTRAIN_FIG = fig;
  DEMTRAIN_RES_FIGS = 0;
  NUM_DEMTRAIN_RES_FIGS = 0;

  % Create GROUP for buttons
  uicontrol(fig, ...
    'Style', 'frame', ...
    'Units', 'normalized', ...
    'Position', [0.03 0.08 0.94 0.22], ...
    'BackgroundColor', [0.5 0.5 0.5]);

  % Create MAIN axis (the only axes in the GUI figure; the training
  % curve is drawn into it by the 'start' action)
  axes( ...
    'Units', 'normalized', ...
    'Position', [0.10 0.5 0.80 0.40], ...
    'XColor', [0 0 0], ...
    'YColor', [0 0 0], ...
    'Visible', 'on');

  % Create static text for FILENAME and PATH
  hFilename = uicontrol(fig, ...
    'Style', 'text', ...
    'Units', 'normalized', ...
    'BackgroundColor', [0.7529 0.7529 0.7529], ...
    'Position', [0.05 0.32 0.90 0.05], ...
    'HorizontalAlignment', 'center', ...
    'String', 'Please load data file.', ...
    'Visible', 'on');
  hPath = uicontrol(fig, ...
    'Style', 'text', ...
    'Units', 'normalized', ...
    'BackgroundColor', [0.7529 0.7529 0.7529], ...
    'Position', [0.05 0.37 0.90 0.05], ...
    'HorizontalAlignment', 'center', ...
    'String', '', ...
    'Visible', 'on');

  % Create NO OF HIDDEN UNITS slider and text
  hSliderText = uicontrol(fig, ...
    'Style', 'text', ...
    'BackgroundColor', [0.5 0.5 0.5], ...
    'Units', 'normalized', ...
    'Position', [0.27 0.12 0.17 0.04], ...
    'HorizontalAlignment', 'right', ...
    'String', 'Hidden Units: 5');
  hSlider = uicontrol(fig, ...
    'Style', 'slider', ...
    'Units', 'normalized', ...
    'Position', [0.45 0.12 0.26 0.04], ...
    'String', 'Slider', ...
    'Min', 1, 'Max', 25, ...
    'Value', 5, ...
    'Callback', 'demtrain slider_moved');

  % Create ITERATIONS slider and text
  hIterationsText = uicontrol(fig, ...
    'Style', 'text', ...
    'BackgroundColor', [0.5 0.5 0.5], ...
    'Units', 'normalized', ...
    'Position', [0.27 0.21 0.17 0.04], ...
    'HorizontalAlignment', 'right', ...
    'String', 'Iterations: 50');
  hIterations = uicontrol(fig, ...
    'Style', 'slider', ...
    'Units', 'normalized', ...
    'Position', [0.45 0.21 0.26 0.04], ...
    'String', 'Slider', ...
    'Min', 10, 'Max', 500, ...
    'Value', 50, ...
    'Callback', 'demtrain iterations_moved');

  % Create ACTIVATION FUNCTION popup and text
  uicontrol(fig, ...
    'Style', 'text', ...
    'BackgroundColor', [0.5 0.5 0.5], ...
    'Units', 'normalized', ...
    'Position', [0.05 0.20 0.20 0.04], ...
    'HorizontalAlignment', 'center', ...
    'String', 'Activation Function:');
  hPopup = uicontrol(fig, ...
    'Style', 'popup', ...
    'Units', 'normalized', ...
    'Position' , [0.05 0.10 0.20 0.08], ...
    'String', 'Linear|Logistic|Softmax', ...
    'Callback', '');

  % Create MENU
  hMenu1 = uimenu('Label', 'Load Data file...', 'Callback', '');
  uimenu(hMenu1, 'Label', 'Select training data file', ...
    'Callback', 'demtrain get_ip_file');
  hMenu2 = uimenu('Label', 'Show Results...', 'Callback', '');
  uimenu(hMenu2, 'Label', 'Show classification results', ...
    'Callback', 'demtrain classify');
  uimenu(hMenu2, 'Label', 'Show regression results', ...
    'Callback', 'demtrain predict');

  % Create START button (disabled until a data file has been chosen)
  hStart = uicontrol(fig, ...
    'Units', 'normalized', ...
    'Position' , [0.75 0.2 0.20 0.08], ...
    'String', 'Start Training', ...
    'Enable', 'off',...
    'Callback', 'demtrain start');

  % Create CLOSE button
  uicontrol(fig, ...
    'Units', 'normalized', ...
    'Position' , [0.75 0.1 0.20 0.08], ...
    'String', 'Close', ...
    'Callback', 'demtrain close');

  % Save handles of important UI objects.  The other actions rely on
  % this exact ordering:
  %   1 hidden-units slider   2 hidden-units text   3 file name text
  %   4 path text             5 activation popup    6 iterations slider
  %   7 iterations text       8 start button
  hndlList = [hSlider hSliderText hFilename hPath hPopup ...
      hIterations hIterationsText hStart];
  set(fig, 'UserData', hndlList);
  % Hide window from command line
  set(fig, 'HandleVisibility', 'callback');


elseif strcmp(action, 'slider_moved'),

  % Hidden-units slider has been moved.

  hndlList = get(gcf, 'UserData');
  hSlider = hndlList(1);
  hSliderText = hndlList(2);

  val = get(hSlider, 'Value');
  % The rounding is deliberately "inverted": a fractional part below 0.5
  % rounds up and one of 0.5 or more rounds down, so that a small arrow
  % step away from an integer lands on the next integer in that
  % direction rather than snapping back.
  if rem(val, 1) < 0.5,     % Force up and down arrows to work!
    val = ceil(val);
  else
    val = floor(val);
  end;
  set(hSlider, 'Value', val);
  set(hSliderText, 'String', ['Hidden Units: ' int2str(val)]);


elseif strcmp(action, 'iterations_moved'),

  % Iterations slider has been moved: only the label is updated, the
  % slider value itself is left as it is.

  hndlList = get(gcf, 'UserData');
  hSlider = hndlList(6);
  hSliderText = hndlList(7);

  val = get(hSlider, 'Value');
  set(hSliderText, 'String', ['Iterations: ' int2str(val)]);

elseif strcmp(action, 'get_ip_file'),

  % Get data file menu item selected.

  hndlList = get(gcf, 'UserData');

  % The obsolete trailing x,y dialog-position arguments are not passed.
  [file, datapath] = uigetfile('*.dat', 'Get Data File');

  % UIGETFILE returns the number 0 when the dialog is cancelled.  Test
  % with ISEQUAL rather than ==, which would compare a file name
  % character by character.
  if isequal(file, 0) | isempty(file),
    set(hndlList(3), 'String', 'No data file loaded.');
    set(hndlList(4), 'String', '');
    % Nothing to train on
    set(hndlList(8), 'Enable', 'off');
  else
    set(hndlList(3), 'String', file);
    set(hndlList(4), 'String', datapath);
    % Enable training button
    set(hndlList(8), 'Enable', 'on');
  end;

  set(gcf, 'UserData', hndlList);

elseif strcmp(action, 'start'),

  % Start training

  % Get handles of and values from UI objects
  hndlList = get(gcf, 'UserData');
  hSlider = hndlList(1);			% No of hidden units
  hIterations = hndlList(6);
  iterations = get(hIterations, 'Value');

  hFilename = hndlList(3);			% Data file name
  filename = get(hFilename, 'String');

  hPath = hndlList(4);				% Data file path
  datapath = get(hPath, 'String');

  hPopup = hndlList(5);				% Activation function
  if get(hPopup, 'Value') == 1,
    act_fn = 'linear';
  elseif get(hPopup, 'Value') == 2,
    act_fn = 'logistic';
  else
    act_fn = 'softmax';
  end;
  nhidden = get(hSlider, 'Value');

  % Build the file name once so that the existence check and the load
  % below refer to exactly the same file.
  datafile = fullfile(datapath, filename);

  % Check data file exists
  if ~can_read(datafile),
    errordlg('Training data file has not been selected.', 'Error');
  else
    % Load data file
    [x,t,nin,nout,ndata] = datread(datafile);

    % Call MLPTRAIN function repeatedly, while drawing training graph.
    figure(DEMTRAIN_FIG);
    hold on;

    title('Training - please wait.');

    % Create net and find initial error
    net = mlp(size(x, 2), nhidden, size(t, 2), act_fn);
    % Initialise network with inverse variance of 10
    net = mlpinit(net, 10);
    err = mlperr(net, x, t);
    % Work out reporting step: should be sufficiently big to let training
    % algorithm have a chance
    step = max(ceil(iterations / 50), 5);

    % Refresh and rescale axis.  The axis limits are kept in YMAX/YMIN
    % (not MAX/MIN) so that the built-in functions stay callable.
    cla;
    ymax = err;
    ymin = ymax/10;
    set(gca, 'YScale', 'log');
    ylabel('log Error');
    xlabel('No. iterations');
    axis([0 iterations ymin ymax+1]);
    iold = 0;
    errold = err;
    % Plot circle to show initial error.
    % NOTE: the 'EraseMode' property formerly set here is not passed any
    % more because it is not available in the graphics system introduced
    % in R2014b; DRAWNOW keeps the display up to date.
    plot(0, err, 'ro');
    hold on
    drawnow; % Force redraw
    for i = step-1:step:iterations,
      [net, err] = mlptrain(net, x, t, step);
      % Plot line from last point to new point.
      line([iold i], [errold err], 'Color', 'r');
      iold = i;
      errold = err;

      % If new point off scale, redraw axes.
      if err > ymax,
        ymax = err;
        axis([0 iterations ymin ymax+1]);
      end;
      if err < ymin
        ymin = err/10;
        axis([0 iterations ymin ymax+1]);
      end
      % Plot circle to show error of last iteration
      plot(i, err, 'ro');
      drawnow; % Force redraw
    end;
    % Written to the current directory; read back by 'classify'/'predict'
    save mlptrain.net net
    zoom on;

    title(['Training complete. Final error=', num2str(err)]);

  end;

elseif strcmp(action, 'close'),

  % Close all the figures we have created
  if ishandle(DEMTRAIN_FIG)
    close(DEMTRAIN_FIG);
  end
  for n = 1:NUM_DEMTRAIN_RES_FIGS
    if ishandle(DEMTRAIN_RES_FIGS(n))
      close(DEMTRAIN_RES_FIGS(n));
    end
  end

elseif strcmp(action, 'classify'),

  if ~can_read('mlptrain.net'),
    errordlg('You have not yet trained the network.', 'Error');
  else

    hndlList = get(gcf, 'UserData');
    filename = get(hndlList(3), 'String');
    datapath = get(hndlList(4), 'String');
    [x,t,nin,nout,ndata] = datread(fullfile(datapath, filename));
    load mlptrain.net net -mat
    y = mlpfwd(net, x);

    % Save results figure so that it can be closed later
    NUM_DEMTRAIN_RES_FIGS = NUM_DEMTRAIN_RES_FIGS + 1;
    DEMTRAIN_RES_FIGS(NUM_DEMTRAIN_RES_FIGS)=conffig(y,t);

  end;

elseif strcmp(action, 'predict'),

  if ~can_read('mlptrain.net'),
    errordlg('You have not yet trained the network.', 'Error');
  else

    hndlList = get(gcf, 'UserData');
    filename = get(hndlList(3), 'String');
    datapath = get(hndlList(4), 'String');
    [x,t,nin,nout,ndata] = datread(fullfile(datapath, filename));
    load mlptrain.net net -mat
    y = mlpfwd(net, x);

    % One figure per network output: output (x axis) against target
    % (y axis), with a dotted reference line from (0,0) to (1,1).
    for i = 1:size(y,2),
      % Save results figure so that it can be closed later
      NUM_DEMTRAIN_RES_FIGS = NUM_DEMTRAIN_RES_FIGS + 1;
      DEMTRAIN_RES_FIGS(NUM_DEMTRAIN_RES_FIGS) = figure;
      hold on;
      title(['Output no ' num2str(i)]);
      plot([0 1], [0 1], 'r:');
      plot(y(:,i),t(:,i), 'o');
      hold off;
    end;
  end;

end;


function ok = can_read(fname)
%CAN_READ Test whether a file can be opened for reading.
%
%	OK = CAN_READ(FNAME) returns 1 if FOPEN succeeds on FNAME and 0
%	otherwise.  The file is closed again immediately, so no file
%	identifier is left open by the test.

fid = fopen(fname);
ok = (fid ~= -1);
if ok
  fclose(fid);
end