Mercurial > hg > camir-aes2014
diff toolboxes/FullBNT-1.0.7/netlab3.3/demtrain.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/toolboxes/FullBNT-1.0.7/netlab3.3/demtrain.m Tue Feb 10 15:05:51 2015 +0000 @@ -0,0 +1,362 @@ +function demtrain(action); +%DEMTRAIN Demonstrate training of MLP network. +% +% Description +% DEMTRAIN brings up a simple GUI to show the training of an MLP +% network on classification and regression problems. The user should +% load in a dataset (which should be in Netlab format: see DATREAD), +% select the output activation function, the number of cycles and +% hidden units and then train the network. The scaled conjugate +% gradient algorithm is used. A graph shows the evolution of the error: +% the value is shown MAX(CEIL(ITERATIONS / 50), 5) cycles. +% +% Once the network is trained, it is saved to the file MLPTRAIN.NET. +% The results can then be viewed as a confusion matrix (for +% classification problems) or a plot of output versus target (for +% regression problems). +% +% See also +% CONFMAT, DATREAD, MLP, NETOPT, SCG +% + +% Copyright (c) Ian T Nabney (1996-2001) + +% If run without parameters, initialise gui. +if nargin<1, + action='initialise'; +end; + +% Global variable to reference GUI figure +global DEMTRAIN_FIG +% Global array to reference sub-figures for results plots +global DEMTRAIN_RES_FIGS +global NUM_DEMTRAIN_RES_FIGS + +if strcmp(action,'initialise'), + + file = ''; + path = '.'; + + % Create FIGURE + fig = figure( ... + 'Name', 'Netlab Demo', ... + 'NumberTitle', 'off', ... + 'Menubar', 'none', ... + 'Color', [0.7529 0.7529 0.7529], ... + 'Visible', 'on'); + % Initialise the globals + DEMTRAIN_FIG = fig; + DEMTRAIN_RES_FIGS = 0; + NUM_DEMTRAIN_RES_FIGS = 0; + + % Create GROUP for buttons + uicontrol(fig, ... + 'Style', 'frame', ... + 'Units', 'normalized', ... + 'Position', [0.03 0.08 0.94 0.22], ... + 'BackgroundColor', [0.5 0.5 0.5]); + + % Create MAIN axis + hMain = axes( ... + 'Units', 'normalized', ... + 'Position', [0.10 0.5 0.80 0.40], ... + 'XColor', [0 0 0], ... + 'YColor', [0 0 0], ... + 'Visible', 'on'); + + % Create static text for FILENAME and PATH + hFilename = uicontrol(fig, ... + 'Style', 'text', ... + 'Units', 'normalized', ... + 'BackgroundColor', [0.7529 0.7529 0.7529], ... + 'Position', [0.05 0.32 0.90 0.05], ... + 'HorizontalAlignment', 'center', ... + 'String', 'Please load data file.', ... + 'Visible', 'on'); + hPath = uicontrol(fig, ... + 'Style', 'text', ... + 'Units', 'normalized', ... + 'BackgroundColor', [0.7529 0.7529 0.7529], ... + 'Position', [0.05 0.37 0.90 0.05], ... + 'HorizontalAlignment', 'center', ... + 'String', '', ... + 'Visible', 'on'); + + % Create NO OF HIDDEN UNITS slider and text + hSliderText = uicontrol(fig, ... + 'Style', 'text', ... + 'BackgroundColor', [0.5 0.5 0.5], ... + 'Units', 'normalized', ... + 'Position', [0.27 0.12 0.17 0.04], ... + 'HorizontalAlignment', 'right', ... + 'String', 'Hidden Units: 5'); + hSlider = uicontrol(fig, ... + 'Style', 'slider', ... + 'Units', 'normalized', ... + 'Position', [0.45 0.12 0.26 0.04], ... + 'String', 'Slider', ... + 'Min', 1, 'Max', 25, ... + 'Value', 5, ... + 'Callback', 'demtrain slider_moved'); + + % Create ITERATIONS slider and text + hIterationsText = uicontrol(fig, ... + 'Style', 'text', ... + 'BackgroundColor', [0.5 0.5 0.5], ... + 'Units', 'normalized', ... + 'Position', [0.27 0.21 0.17 0.04], ... + 'HorizontalAlignment', 'right', ... + 'String', 'Iterations: 50'); + hIterations = uicontrol(fig, ... + 'Style', 'slider', ... + 'Units', 'normalized', ... + 'Position', [0.45 0.21 0.26 0.04], ... + 'String', 'Slider', ... + 'Min', 10, 'Max', 500, ... + 'Value', 50, ... + 'Callback', 'demtrain iterations_moved'); + + % Create ACTIVATION FUNCTION popup and text + uicontrol(fig, ... + 'Style', 'text', ... + 'BackgroundColor', [0.5 0.5 0.5], ... + 'Units', 'normalized', ... + 'Position', [0.05 0.20 0.20 0.04], ... + 'HorizontalAlignment', 'center', ... + 'String', 'Activation Function:'); + hPopup = uicontrol(fig, ... + 'Style', 'popup', ... + 'Units', 'normalized', ... + 'Position' , [0.05 0.10 0.20 0.08], ... + 'String', 'Linear|Logistic|Softmax', ... + 'Callback', ''); + + % Create MENU + hMenu1 = uimenu('Label', 'Load Data file...', 'Callback', ''); + uimenu(hMenu1, 'Label', 'Select training data file', ... + 'Callback', 'demtrain get_ip_file'); + hMenu2 = uimenu('Label', 'Show Results...', 'Callback', ''); + uimenu(hMenu2, 'Label', 'Show classification results', ... + 'Callback', 'demtrain classify'); + uimenu(hMenu2, 'Label', 'Show regression results', ... + 'Callback', 'demtrain predict'); + + % Create START button + hStart = uicontrol(fig, ... + 'Units', 'normalized', ... + 'Position' , [0.75 0.2 0.20 0.08], ... + 'String', 'Start Training', ... + 'Enable', 'off',... + 'Callback', 'demtrain start'); + + % Create CLOSE button + uicontrol(fig, ... + 'Units', 'normalized', ... + 'Position' , [0.75 0.1 0.20 0.08], ... + 'String', 'Close', ... + 'Callback', 'demtrain close'); + + % Save handles of important UI objects + hndlList = [hSlider hSliderText hFilename hPath hPopup ... + hIterations hIterationsText hStart]; + set(fig, 'UserData', hndlList); + % Hide window from command line + set(fig, 'HandleVisibility', 'callback'); + + +elseif strcmp(action, 'slider_moved'), + + % Slider has been moved. + + hndlList = get(gcf, 'UserData'); + hSlider = hndlList(1); + hSliderText = hndlList(2); + + val = get(hSlider, 'Value'); + if rem(val, 1) < 0.5, % Force up and down arrows to work! + val = ceil(val); + else + val = floor(val); + end; + set(hSlider, 'Value', val); + set(hSliderText, 'String', ['Hidden Units: ' int2str(val)]); + + +elseif strcmp(action, 'iterations_moved'), + + % Slider has been moved. + + hndlList = get(gcf, 'UserData'); + hSlider = hndlList(6); + hSliderText = hndlList(7); + + val = get(hSlider, 'Value'); + set(hSliderText, 'String', ['Iterations: ' int2str(val)]); + +elseif strcmp(action, 'get_ip_file'), + + % Get data file button pressed. + + hndlList = get(gcf, 'UserData'); + + [file, path] = uigetfile('*.dat', 'Get Data File', 50, 50); + + if strcmp(file, '') | file == 0, + set(hndlList(3), 'String', 'No data file loaded.'); + set(hndlList(4), 'String', ''); + else + set(hndlList(3), 'String', file); + set(hndlList(4), 'String', path); + end; + + % Enable training button + set(hndlList(8), 'Enable', 'on'); + + set(gcf, 'UserData', hndlList); + +elseif strcmp(action, 'start'), + + % Start training + + % Get handles of and values from UI objects + hndlList = get(gcf, 'UserData'); + hSlider = hndlList(1); % No of hidden units + hIterations = hndlList(6); + iterations = get(hIterations, 'Value'); + + hFilename = hndlList(3); % Data file name + filename = get(hFilename, 'String'); + + hPath = hndlList(4); % Data file path + path = get(hPath, 'String'); + + hPopup = hndlList(5); % Activation function + if get(hPopup, 'Value') == 1, + act_fn = 'linear'; + elseif get(hPopup, 'Value') == 2, + act_fn = 'logistic'; + else + act_fn = 'softmax'; + end; + nhidden = get(hSlider, 'Value'); + + % Check data file exists + if fopen([path '/' filename]) == -1, + errordlg('Training data file has not been selected.', 'Error'); + else + % Load data file + [x,t,nin,nout,ndata] = datread([path filename]); + + % Call MLPTRAIN function repeatedly, while drawing training graph. + figure(DEMTRAIN_FIG); + hold on; + + title('Training - please wait.'); + + % Create net and find initial error + net = mlp(size(x, 2), nhidden, size(t, 2), act_fn); + % Initialise network with inverse variance of 10 + net = mlpinit(net, 10); + error = mlperr(net, x, t); + % Work out reporting step: should be sufficiently big to let training + % algorithm have a chance + step = max(ceil(iterations / 50), 5); + + % Refresh and rescale axis. + cla; + max = error; + min = max/10; + set(gca, 'YScale', 'log'); + ylabel('log Error'); + xlabel('No. iterations'); + axis([0 iterations min max+1]); + iold = 0; + errold = error; + % Plot circle to show error of last iteration + % Setting erase mode to none prevents screen flashing during + % training + plot(0, error, 'ro', 'EraseMode', 'none'); + hold on + drawnow; % Force redraw + for i = step-1:step:iterations, + [net, error] = mlptrain(net, x, t, step); + % Plot line from last point to new point. + line([iold i], [errold error], 'Color', 'r', 'EraseMode', 'none'); + iold = i; + errold = error; + + % If new point off scale, redraw axes. + if error > max, + max = error; + axis([0 iterations min max+1]); + end; + if error < min + min = error/10; + axis([0 iterations min max+1]); + end + % Plot circle to show error of last iteration + plot(i, error, 'ro', 'EraseMode', 'none'); + drawnow; % Force redraw + end; + save mlptrain.net net + zoom on; + + title(['Training complete. Final error=', num2str(error)]); + + end; + +elseif strcmp(action, 'close'), + + % Close all the figures we have created + close(DEMTRAIN_FIG); + for n = 1:NUM_DEMTRAIN_RES_FIGS + if ishandle(DEMTRAIN_RES_FIGS(n)) + close(DEMTRAIN_RES_FIGS(n)); + end + end + +elseif strcmp(action, 'classify'), + + if fopen('mlptrain.net') == -1, + errordlg('You have not yet trained the network.', 'Error'); + else + + hndlList = get(gcf, 'UserData'); + filename = get(hndlList(3), 'String'); + path = get(hndlList(4), 'String'); + [x,t,nin,nout,ndata] = datread([path filename]); + load mlptrain.net net -mat + y = mlpfwd(net, x); + + % Save results figure so that it can be closed later + NUM_DEMTRAIN_RES_FIGS = NUM_DEMTRAIN_RES_FIGS + 1; + DEMTRAIN_RES_FIGS(NUM_DEMTRAIN_RES_FIGS)=conffig(y,t); + + end; + +elseif strcmp(action, 'predict'), + + if fopen('mlptrain.net') == -1, + errordlg('You have not yet trained the network.', 'Error'); + else + + hndlList = get(gcf, 'UserData'); + filename = get(hndlList(3), 'String'); + path = get(hndlList(4), 'String'); + [x,t,nin,nout,ndata] = datread([path filename]); + load mlptrain.net net -mat + y = mlpfwd(net, x); + + for i = 1:size(y,2), + % Save results figure so that it can be closed later + NUM_DEMTRAIN_RES_FIGS = NUM_DEMTRAIN_RES_FIGS + 1; + DEMTRAIN_RES_FIGS(NUM_DEMTRAIN_RES_FIGS) = figure; + hold on; + title(['Output no ' num2str(i)]); + plot([0 1], [0 1], 'r:'); + plot(y(:,i),t(:,i), 'o'); + hold off; + end; + end; + +end; \ No newline at end of file