function demtrain(action)
%DEMTRAIN Demonstrate training of MLP network.
%
%   Description
%   DEMTRAIN brings up a simple GUI to show the training of an MLP
%   network on classification and regression problems. The user should
%   load in a dataset (which should be in Netlab format: see DATREAD),
%   select the output activation function, the number of cycles and
%   hidden units and then train the network. The scaled conjugate
%   gradient algorithm is used. A graph shows the evolution of the error:
%   the value is shown every MAX(CEIL(ITERATIONS / 50), 5) cycles.
%
%   Once the network is trained, it is saved to the file MLPTRAIN.NET
%   in the current directory. The results can then be viewed as a
%   confusion matrix (for classification problems) or a plot of output
%   versus target (for regression problems).
%
%   See also
%   CONFMAT, DATREAD, MLP, NETOPT, SCG
%

%   Copyright (c) Ian T Nabney (1996-2001)

% ACTION selects what to do; apart from the initial call it is supplied
% by the GUI callbacks created below:
%   'initialise'       - build the GUI (default when called with no args)
%   'slider_moved'     - snap the hidden-units slider to an integer
%   'iterations_moved' - snap the iterations slider to an integer
%   'get_ip_file'      - choose the training data file
%   'start'            - train the network and plot the error curve
%   'close'            - close the GUI and any results figures
%   'classify'         - show a confusion matrix (via CONFFIG)
%   'predict'          - plot each network output against its target

% If run without parameters, initialise gui.
if nargin < 1,
  action = 'initialise';
end;

% Global variable to reference GUI figure
global DEMTRAIN_FIG
% Global array to reference sub-figures for results plots
global DEMTRAIN_RES_FIGS
global NUM_DEMTRAIN_RES_FIGS

if strcmp(action, 'initialise'),

  % Create FIGURE
  fig = figure( ...
    'Name', 'Netlab Demo', ...
    'NumberTitle', 'off', ...
    'Menubar', 'none', ...
    'Color', [0.7529 0.7529 0.7529], ...
    'Visible', 'on');
  % Initialise the globals
  DEMTRAIN_FIG = fig;
  DEMTRAIN_RES_FIGS = 0;
  NUM_DEMTRAIN_RES_FIGS = 0;

  % Create GROUP for buttons
  uicontrol(fig, ...
    'Style', 'frame', ...
    'Units', 'normalized', ...
    'Position', [0.03 0.08 0.94 0.22], ...
    'BackgroundColor', [0.5 0.5 0.5]);

  % Create MAIN axis (used for the training error graph)
  axes( ...
    'Units', 'normalized', ...
    'Position', [0.10 0.5 0.80 0.40], ...
    'XColor', [0 0 0], ...
    'YColor', [0 0 0], ...
    'Visible', 'on');

  % Create static text for FILENAME and PATH
  hFilename = uicontrol(fig, ...
    'Style', 'text', ...
    'Units', 'normalized', ...
    'BackgroundColor', [0.7529 0.7529 0.7529], ...
    'Position', [0.05 0.32 0.90 0.05], ...
    'HorizontalAlignment', 'center', ...
    'String', 'Please load data file.', ...
    'Visible', 'on');
  hPath = uicontrol(fig, ...
    'Style', 'text', ...
    'Units', 'normalized', ...
    'BackgroundColor', [0.7529 0.7529 0.7529], ...
    'Position', [0.05 0.37 0.90 0.05], ...
    'HorizontalAlignment', 'center', ...
    'String', '', ...
    'Visible', 'on');

  % Create NO OF HIDDEN UNITS slider and text
  hSliderText = uicontrol(fig, ...
    'Style', 'text', ...
    'BackgroundColor', [0.5 0.5 0.5], ...
    'Units', 'normalized', ...
    'Position', [0.27 0.12 0.17 0.04], ...
    'HorizontalAlignment', 'right', ...
    'String', 'Hidden Units: 5');
  hSlider = uicontrol(fig, ...
    'Style', 'slider', ...
    'Units', 'normalized', ...
    'Position', [0.45 0.12 0.26 0.04], ...
    'String', 'Slider', ...
    'Min', 1, 'Max', 25, ...
    'Value', 5, ...
    'Callback', 'demtrain slider_moved');

  % Create ITERATIONS slider and text
  hIterationsText = uicontrol(fig, ...
    'Style', 'text', ...
    'BackgroundColor', [0.5 0.5 0.5], ...
    'Units', 'normalized', ...
    'Position', [0.27 0.21 0.17 0.04], ...
    'HorizontalAlignment', 'right', ...
    'String', 'Iterations: 50');
  hIterations = uicontrol(fig, ...
    'Style', 'slider', ...
    'Units', 'normalized', ...
    'Position', [0.45 0.21 0.26 0.04], ...
    'String', 'Slider', ...
    'Min', 10, 'Max', 500, ...
    'Value', 50, ...
    'Callback', 'demtrain iterations_moved');

  % Create ACTIVATION FUNCTION popup and text
  uicontrol(fig, ...
    'Style', 'text', ...
    'BackgroundColor', [0.5 0.5 0.5], ...
    'Units', 'normalized', ...
    'Position', [0.05 0.20 0.20 0.04], ...
    'HorizontalAlignment', 'center', ...
    'String', 'Activation Function:');
  hPopup = uicontrol(fig, ...
    'Style', 'popup', ...
    'Units', 'normalized', ...
    'Position' , [0.05 0.10 0.20 0.08], ...
    'String', 'Linear|Logistic|Softmax', ...
    'Callback', '');

  % Create MENU
  hMenu1 = uimenu('Label', 'Load Data file...', 'Callback', '');
  uimenu(hMenu1, 'Label', 'Select training data file', ...
    'Callback', 'demtrain get_ip_file');
  hMenu2 = uimenu('Label', 'Show Results...', 'Callback', '');
  uimenu(hMenu2, 'Label', 'Show classification results', ...
    'Callback', 'demtrain classify');
  uimenu(hMenu2, 'Label', 'Show regression results', ...
    'Callback', 'demtrain predict');

  % Create START button (enabled once a data file has been chosen)
  hStart = uicontrol(fig, ...
    'Units', 'normalized', ...
    'Position' , [0.75 0.2 0.20 0.08], ...
    'String', 'Start Training', ...
    'Enable', 'off',...
    'Callback', 'demtrain start');

  % Create CLOSE button
  uicontrol(fig, ...
    'Units', 'normalized', ...
    'Position' , [0.75 0.1 0.20 0.08], ...
    'String', 'Close', ...
    'Callback', 'demtrain close');

  % Save handles of important UI objects. The callbacks below index this
  % list by position:
  %   1 hidden-units slider   2 its label         3 file name text
  %   4 path text             5 activation popup  6 iterations slider
  %   7 its label             8 Start button
  hndlList = [hSlider hSliderText hFilename hPath hPopup ...
    hIterations hIterationsText hStart];
  set(fig, 'UserData', hndlList);
  % Hide window from command line
  set(fig, 'HandleVisibility', 'callback');


elseif strcmp(action, 'slider_moved'),

  % Hidden-units slider has been moved.
  hndlList = get(gcf, 'UserData');
  hSlider = hndlList(1);
  hSliderText = hndlList(2);

  val = get(hSlider, 'Value');
  % Force up and down arrows to work! An arrow click moves the slider by
  % less than one unit, so round AWAY from the previous integer: a small
  % increase goes up to the next integer, a small decrease down to the
  % previous one.
  if rem(val, 1) < 0.5,
    val = ceil(val);
  else
    val = floor(val);
  end;
  set(hSlider, 'Value', val);
  set(hSliderText, 'String', ['Hidden Units: ' int2str(val)]);


elseif strcmp(action, 'iterations_moved'),

  % Iterations slider has been moved. Snap it to the nearest integer so
  % the value used for training is the value shown in the label.
  hndlList = get(gcf, 'UserData');
  hIterations = hndlList(6);
  hIterationsText = hndlList(7);

  val = round(get(hIterations, 'Value'));
  set(hIterations, 'Value', val);
  set(hIterationsText, 'String', ['Iterations: ' int2str(val)]);


elseif strcmp(action, 'get_ip_file'),

  % Get data file menu item selected.
  hndlList = get(gcf, 'UserData');

  [fname, fpath] = uigetfile('*.dat', 'Get Data File');

  % uigetfile returns 0 for the file name when the dialog is cancelled.
  if isequal(fname, 0) || isempty(fname),
    set(hndlList(3), 'String', 'No data file loaded.');
    set(hndlList(4), 'String', '');
    % Nothing to train on, so keep the training button disabled.
    set(hndlList(8), 'Enable', 'off');
  else
    set(hndlList(3), 'String', fname);
    set(hndlList(4), 'String', fpath);
    % Enable training button
    set(hndlList(8), 'Enable', 'on');
  end;


elseif strcmp(action, 'start'),

  % Start training

  % Get handles of and values from UI objects. Slider values may be
  % fractional, but MLP and MLPTRAIN need integer counts.
  hndlList = get(gcf, 'UserData');
  nhidden = round(get(hndlList(1), 'Value'));     % No of hidden units
  iterations = round(get(hndlList(6), 'Value'));  % No of training cycles
  filename = get(hndlList(3), 'String');          % Data file name
  datapath = get(hndlList(4), 'String');          % Data file path

  switch get(hndlList(5), 'Value')                % Activation function
    case 1
      act_fn = 'linear';
    case 2
      act_fn = 'logistic';
    otherwise
      act_fn = 'softmax';
  end;

  % Check data file exists. EXIST is used instead of FOPEN so no file
  % handle is left open, and the same path is then passed to DATREAD.
  datafile = fullfile(datapath, filename);
  if exist(datafile, 'file') ~= 2,
    errordlg('Training data file has not been selected.', 'Error');
  else
    % Load data file
    [x, t] = datread(datafile);

    % Call MLPTRAIN repeatedly, while drawing training graph.
    figure(DEMTRAIN_FIG);
    hold on;

    title('Training - please wait.');

    % Create net and find initial error
    net = mlp(size(x, 2), nhidden, size(t, 2), act_fn);
    % Initialise network with inverse variance of 10
    net = mlpinit(net, 10);
    err = mlperr(net, x, t);
    % Work out reporting step: should be sufficiently big to let training
    % algorithm have a chance
    step = max(ceil(iterations / 50), 5);

    % Refresh and rescale axis. The y limits live in ymin/ymax so that
    % the MIN and MAX built-ins are not shadowed.
    cla;
    ymax = err;
    ymin = ymax / 10;
    set(gca, 'YScale', 'log');
    ylabel('log Error');
    xlabel('No. iterations');
    axis([0 iterations ymin ymax+1]);
    % Plot circle to show error of last iteration. ('EraseMode', 'none'
    % is no longer used: current MATLAB graphics do not support it, and
    % DRAWNOW keeps the graph up to date.)
    plot(0, err, 'ro');
    drawnow;  % Force redraw

    % Train in chunks of STEP cycles (the last chunk may be shorter), so
    % that exactly ITERATIONS cycles are run. Each point is plotted at the
    % number of cycles completed so far.
    ncycles = 0;
    errold = err;
    while ncycles < iterations,
      nstep = min(step, iterations - ncycles);
      [net, err] = mlptrain(net, x, t, nstep);
      % Plot line from last point to new point.
      line([ncycles ncycles+nstep], [errold err], 'Color', 'r');
      ncycles = ncycles + nstep;
      errold = err;

      % If new point off scale, redraw axes.
      if err > ymax,
        ymax = err;
        axis([0 iterations ymin ymax+1]);
      end;
      if err < ymin,
        ymin = err / 10;
        axis([0 iterations ymin ymax+1]);
      end;
      % Plot circle to show error of last iteration
      plot(ncycles, err, 'ro');
      drawnow;  % Force redraw
    end;
    save mlptrain.net net
    zoom on;

    title(['Training complete. Final error=', num2str(err)]);

  end;

elseif strcmp(action, 'close'),

  % Close all the figures we have created (skipping any already closed)
  if ishandle(DEMTRAIN_FIG),
    close(DEMTRAIN_FIG);
  end;
  for n = 1:NUM_DEMTRAIN_RES_FIGS
    if ishandle(DEMTRAIN_RES_FIGS(n))
      close(DEMTRAIN_RES_FIGS(n));
    end
  end

elseif strcmp(action, 'classify'),

  [y, t, ok] = demtrain_results(gcf);
  if ok,
    % Save results figure so that it can be closed later
    NUM_DEMTRAIN_RES_FIGS = NUM_DEMTRAIN_RES_FIGS + 1;
    DEMTRAIN_RES_FIGS(NUM_DEMTRAIN_RES_FIGS) = conffig(y, t);
  end;

elseif strcmp(action, 'predict'),

  [y, t, ok] = demtrain_results(gcf);
  if ok,
    % One scatter plot of output against target per network output
    for i = 1:size(y, 2),
      % Save results figure so that it can be closed later
      NUM_DEMTRAIN_RES_FIGS = NUM_DEMTRAIN_RES_FIGS + 1;
      DEMTRAIN_RES_FIGS(NUM_DEMTRAIN_RES_FIGS) = figure;
      hold on;
      title(['Output no ' num2str(i)]);
      plot([0 1], [0 1], 'r:');
      plot(y(:, i), t(:, i), 'o');
      hold off;
    end;
  end;

end;


function [y, t, ok] = demtrain_results(fig)
%DEMTRAIN_RESULTS Apply the saved network to the selected data file.
%
%   [Y, T, OK] = DEMTRAIN_RESULTS(FIG) reads the data file whose name and
%   path are shown in the DEMTRAIN figure FIG, loads the network saved in
%   MLPTRAIN.NET, and returns the network outputs Y and the targets T.
%   If MLPTRAIN.NET does not exist, an error dialog is shown, OK is 0 and
%   Y and T are empty; otherwise OK is 1.

y = [];
t = [];
ok = 0;

% EXIST is used instead of FOPEN so that no file handle is left open.
if exist('mlptrain.net', 'file') ~= 2,
  errordlg('You have not yet trained the network.', 'Error');
  return;
end;

hndlList = get(fig, 'UserData');
filename = get(hndlList(3), 'String');
datapath = get(hndlList(4), 'String');
[x, t] = datread(fullfile(datapath, filename));
% Load into a struct so that NET is an explicit variable in this
% function rather than one created implicitly by LOAD.
saved = load('mlptrain.net', '-mat', 'net');
y = mlpfwd(saved.net, x);
ok = 1;