Mercurial > hg > camir-aes2014
view toolboxes/FullBNT-1.0.7/netlab3.3/demtrain.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
line wrap: on
line source
function demtrain(action)
%DEMTRAIN Demonstrate training of MLP network.
%
%	Description
%	DEMTRAIN brings up a simple GUI to show the training of an MLP
%	network on classification and regression problems.  The user should
%	load in a dataset (which should be in Netlab format: see DATREAD),
%	select the output activation function, the number of cycles and
%	hidden units and then train the network. The scaled conjugate
%	gradient algorithm is used.  A graph shows the evolution of the
%	error: the value is shown every MAX(CEIL(ITERATIONS / 50), 5) cycles.
%
%	Once the network is trained, it is saved to the file MLPTRAIN.NET.
%	The results can then be viewed as a confusion matrix (for
%	classification problems) or a plot of output versus target (for
%	regression problems).
%
%	ACTION is an optional string selecting the GUI event to handle:
%	'initialise' (default), 'slider_moved', 'iterations_moved',
%	'get_ip_file', 'start', 'close', 'classify' or 'predict'.  The GUI
%	callbacks re-enter this function with the appropriate ACTION.
%
%	See also
%	CONFMAT, DATREAD, MLP, NETOPT, SCG
%

%	Copyright (c) Ian T Nabney (1996-2001)

% If run without parameters, initialise gui.
if nargin < 1
  action = 'initialise';
end

% Global variable to reference GUI figure
global DEMTRAIN_FIG
% Global array to reference sub-figures for results plots
global DEMTRAIN_RES_FIGS
global NUM_DEMTRAIN_RES_FIGS

if strcmp(action, 'initialise')

  % Create FIGURE
  fig = figure( ...
    'Name', 'Netlab Demo', ...
    'NumberTitle', 'off', ...
    'Menubar', 'none', ...
    'Color', [0.7529 0.7529 0.7529], ...
    'Visible', 'on');
  % Initialise the globals
  DEMTRAIN_FIG = fig;
  DEMTRAIN_RES_FIGS = 0;
  NUM_DEMTRAIN_RES_FIGS = 0;

  % Create GROUP for buttons
  uicontrol(fig, ...
    'Style', 'frame', ...
    'Units', 'normalized', ...
    'Position', [0.03 0.08 0.94 0.22], ...
    'BackgroundColor', [0.5 0.5 0.5]);

  % Create MAIN axis (the training error curve is drawn here)
  axes( ...
    'Units', 'normalized', ...
    'Position', [0.10 0.5 0.80 0.40], ...
    'XColor', [0 0 0], ...
    'YColor', [0 0 0], ...
    'Visible', 'on');

  % Create static text for FILENAME and PATH
  hFilename = uicontrol(fig, ...
    'Style', 'text', ...
    'Units', 'normalized', ...
    'BackgroundColor', [0.7529 0.7529 0.7529], ...
    'Position', [0.05 0.32 0.90 0.05], ...
    'HorizontalAlignment', 'center', ...
    'String', 'Please load data file.', ...
    'Visible', 'on');
  hPath = uicontrol(fig, ...
    'Style', 'text', ...
    'Units', 'normalized', ...
    'BackgroundColor', [0.7529 0.7529 0.7529], ...
    'Position', [0.05 0.37 0.90 0.05], ...
    'HorizontalAlignment', 'center', ...
    'String', '', ...
    'Visible', 'on');

  % Create NO OF HIDDEN UNITS slider and text
  hSliderText = uicontrol(fig, ...
    'Style', 'text', ...
    'BackgroundColor', [0.5 0.5 0.5], ...
    'Units', 'normalized', ...
    'Position', [0.27 0.12 0.17 0.04], ...
    'HorizontalAlignment', 'right', ...
    'String', 'Hidden Units: 5');
  hSlider = uicontrol(fig, ...
    'Style', 'slider', ...
    'Units', 'normalized', ...
    'Position', [0.45 0.12 0.26 0.04], ...
    'String', 'Slider', ...
    'Min', 1, 'Max', 25, ...
    'Value', 5, ...
    'Callback', 'demtrain slider_moved');

  % Create ITERATIONS slider and text
  hIterationsText = uicontrol(fig, ...
    'Style', 'text', ...
    'BackgroundColor', [0.5 0.5 0.5], ...
    'Units', 'normalized', ...
    'Position', [0.27 0.21 0.17 0.04], ...
    'HorizontalAlignment', 'right', ...
    'String', 'Iterations: 50');
  hIterations = uicontrol(fig, ...
    'Style', 'slider', ...
    'Units', 'normalized', ...
    'Position', [0.45 0.21 0.26 0.04], ...
    'String', 'Slider', ...
    'Min', 10, 'Max', 500, ...
    'Value', 50, ...
    'Callback', 'demtrain iterations_moved');

  % Create ACTIVATION FUNCTION popup and text
  uicontrol(fig, ...
    'Style', 'text', ...
    'BackgroundColor', [0.5 0.5 0.5], ...
    'Units', 'normalized', ...
    'Position', [0.05 0.20 0.20 0.04], ...
    'HorizontalAlignment', 'center', ...
    'String', 'Activation Function:');
  hPopup = uicontrol(fig, ...
    'Style', 'popup', ...
    'Units', 'normalized', ...
    'Position', [0.05 0.10 0.20 0.08], ...
    'String', 'Linear|Logistic|Softmax', ...
    'Callback', '');

  % Create MENU
  hMenu1 = uimenu('Label', 'Load Data file...', 'Callback', '');
  uimenu(hMenu1, 'Label', 'Select training data file', ...
    'Callback', 'demtrain get_ip_file');
  hMenu2 = uimenu('Label', 'Show Results...', 'Callback', '');
  uimenu(hMenu2, 'Label', 'Show classification results', ...
    'Callback', 'demtrain classify');
  uimenu(hMenu2, 'Label', 'Show regression results', ...
    'Callback', 'demtrain predict');

  % Create START button (disabled until a data file has been chosen)
  hStart = uicontrol(fig, ...
    'Units', 'normalized', ...
    'Position', [0.75 0.2 0.20 0.08], ...
    'String', 'Start Training', ...
    'Enable', 'off', ...
    'Callback', 'demtrain start');

  % Create CLOSE button
  uicontrol(fig, ...
    'Units', 'normalized', ...
    'Position', [0.75 0.1 0.20 0.08], ...
    'String', 'Close', ...
    'Callback', 'demtrain close');

  % Save handles of important UI objects.  The other branches index this
  % list by position, so the order here must not change.
  hndlList = [hSlider hSliderText hFilename hPath hPopup ...
    hIterations hIterationsText hStart];
  set(fig, 'UserData', hndlList);
  % Hide window from command line
  set(fig, 'HandleVisibility', 'callback');

elseif strcmp(action, 'slider_moved')

  % Hidden-units slider has been moved.
  hndlList = get(gcf, 'UserData');
  hSlider = hndlList(1);
  hSliderText = hndlList(2);

  val = get(hSlider, 'Value');
  % Force up and down arrows to work!  The rounding is deliberately
  % "inverted": a fractional part below 0.5 snaps up and one of 0.5 or
  % more snaps down, so a small move away from an integer lands on the
  % next integer in the direction of travel.
  if rem(val, 1) < 0.5
    val = ceil(val);
  else
    val = floor(val);
  end
  set(hSlider, 'Value', val);
  set(hSliderText, 'String', ['Hidden Units: ' int2str(val)]);

elseif strcmp(action, 'iterations_moved')

  % Iterations slider has been moved.
  hndlList = get(gcf, 'UserData');
  hSlider = hndlList(6);
  hSliderText = hndlList(7);

  val = get(hSlider, 'Value');
  set(hSliderText, 'String', ['Iterations: ' int2str(val)]);

elseif strcmp(action, 'get_ip_file')

  % Get data file menu item selected.
  hndlList = get(gcf, 'UserData');

  % The obsolete positional x,y arguments are omitted; they only
  % affected where the dialog was placed.
  [file, datapath] = uigetfile('*.dat', 'Get Data File');

  % uigetfile returns numeric 0 when the dialog is cancelled.
  if isequal(file, 0) || isempty(file)
    set(hndlList(3), 'String', 'No data file loaded.');
    set(hndlList(4), 'String', '');
    % Nothing to train on, so keep the training button disabled.
    set(hndlList(8), 'Enable', 'off');
  else
    set(hndlList(3), 'String', file);
    set(hndlList(4), 'String', datapath);
    % Enable training button
    set(hndlList(8), 'Enable', 'on');
  end

elseif strcmp(action, 'start')

  % Start training

  % Get handles of and values from UI objects
  hndlList = get(gcf, 'UserData');
  hSlider = hndlList(1);      % No of hidden units
  hIterations = hndlList(6);  % No of training cycles
  hFilename = hndlList(3);    % Data file name
  hPath = hndlList(4);        % Data file path
  hPopup = hndlList(5);       % Activation function

  % Round so that training uses the same count that the label shows
  % (the label is produced with int2str).
  iterations = round(get(hIterations, 'Value'));
  nhidden = round(get(hSlider, 'Value'));
  filename = get(hFilename, 'String');
  datapath = get(hPath, 'String');

  popupChoice = get(hPopup, 'Value');
  if popupChoice == 1
    act_fn = 'linear';
  elseif popupChoice == 2
    act_fn = 'logistic';
  else
    act_fn = 'softmax';
  end

  % Use the same full path for the existence check and for loading.
  datafile = fullfile(datapath, filename);

  % Check data file exists
  if ~file_readable(datafile)
    errordlg('Training data file has not been selected.', 'Error');
  else
    % Load data file
    [x, t] = datread(datafile);

    % Call MLPTRAIN function repeatedly, while drawing training graph.
    figure(DEMTRAIN_FIG);
    hold on;

    title('Training - please wait.');

    % Create net and find initial error
    net = mlp(size(x, 2), nhidden, size(t, 2), act_fn);
    % Initialise network with inverse variance of 10
    net = mlpinit(net, 10);
    err = mlperr(net, x, t);
    % Work out reporting step: should be sufficiently big to let training
    % algorithm have a chance
    step = max(ceil(iterations / 50), 5);

    % Refresh and rescale axis.
    cla;
    errMax = err;
    errMin = errMax/10;
    set(gca, 'YScale', 'log');
    ylabel('log Error');
    xlabel('No. iterations');
    axis([0 iterations errMin errMax+1]);
    iold = 0;
    errold = err;

    % Plot circle to show error of last iteration
    % Setting erase mode to none prevents screen flashing during
    % training
    % NOTE(review): 'EraseMode' is kept from the original code; confirm
    % that the MATLAB release in use still accepts this property.
    plot(0, err, 'ro', 'EraseMode', 'none');
    hold on
    drawnow; % Force redraw

    for i = step-1:step:iterations
      [net, err] = mlptrain(net, x, t, step);
      % Plot line from last point to new point.
      line([iold i], [errold err], 'Color', 'r', 'EraseMode', 'none');
      iold = i;
      errold = err;

      % If new point off scale, redraw axes.
      if err > errMax
        errMax = err;
        axis([0 iterations errMin errMax+1]);
      end
      if err < errMin
        errMin = err/10;
        axis([0 iterations errMin errMax+1]);
      end
      % Plot circle to show error of last iteration
      plot(i, err, 'ro', 'EraseMode', 'none');
      drawnow; % Force redraw
    end
    % Written in MAT-file format despite the .net extension.
    save('mlptrain.net', 'net');
    zoom on;

    title(['Training complete. Final error=', num2str(err)]);

  end

elseif strcmp(action, 'close')

  % Close all the figures we have created
  if ishandle(DEMTRAIN_FIG)
    close(DEMTRAIN_FIG);
  end
  for n = 1:NUM_DEMTRAIN_RES_FIGS
    if ishandle(DEMTRAIN_RES_FIGS(n))
      close(DEMTRAIN_RES_FIGS(n));
    end
  end

elseif strcmp(action, 'classify')

  if ~file_readable('mlptrain.net')
    errordlg('You have not yet trained the network.', 'Error');
  else

    hndlList = get(gcf, 'UserData');
    filename = get(hndlList(3), 'String');
    datapath = get(hndlList(4), 'String');
    [x, t] = datread(fullfile(datapath, filename));
    % Functional form of LOAD so that NET is assigned explicitly.
    saved = load('mlptrain.net', '-mat', 'net');
    net = saved.net;
    y = mlpfwd(net, x);

    % Save results figure so that it can be closed later
    NUM_DEMTRAIN_RES_FIGS = NUM_DEMTRAIN_RES_FIGS + 1;
    DEMTRAIN_RES_FIGS(NUM_DEMTRAIN_RES_FIGS) = conffig(y, t);

  end

elseif strcmp(action, 'predict')

  if ~file_readable('mlptrain.net')
    errordlg('You have not yet trained the network.', 'Error');
  else

    hndlList = get(gcf, 'UserData');
    filename = get(hndlList(3), 'String');
    datapath = get(hndlList(4), 'String');
    [x, t] = datread(fullfile(datapath, filename));
    % Functional form of LOAD so that NET is assigned explicitly.
    saved = load('mlptrain.net', '-mat', 'net');
    net = saved.net;
    y = mlpfwd(net, x);

    % One figure per network output: output versus target.
    for i = 1:size(y, 2)
      % Save results figure so that it can be closed later
      NUM_DEMTRAIN_RES_FIGS = NUM_DEMTRAIN_RES_FIGS + 1;
      DEMTRAIN_RES_FIGS(NUM_DEMTRAIN_RES_FIGS) = figure;
      hold on;
      title(['Output no ' num2str(i)]);
      plot([0 1], [0 1], 'r:');
      plot(y(:,i), t(:,i), 'o');
      hold off;
    end
  end

end


function ok = file_readable(fname)
%FILE_READABLE True if FNAME can be opened for reading.
%	The file is closed again immediately, so probing for a file does
%	not leave an open file handle behind.
fid = fopen(fname);
ok = (fid ~= -1);
if ok
  fclose(fid);
end