% view toolboxes/FullBNT-1.0.7/netlab3.3/demtrain.m @ 0:e9a9cd732c1e tip
%
% first hg version after svn
% author wolffd
% date Tue, 10 Feb 2015 15:05:51 +0000
% parents
% children
% line wrap: on
% line source
function demtrain(action);
%DEMTRAIN Demonstrate training of MLP network.
%
%	Description
%	DEMTRAIN brings up a simple GUI to show the training of an MLP
%	network on classification and regression problems.  The user should
%	load in a dataset (which should be in Netlab format: see  DATREAD),
%	select the output activation function, the  number of cycles and
%	hidden units and then train the network. The scaled conjugate
%	gradient algorithm is used. A graph shows the evolution of the error:
%	the value is plotted every MAX(CEIL(ITERATIONS / 50), 5) cycles.
%
%	Once the network is trained, it is saved to the file MLPTRAIN.NET.
%	The results can then be viewed as a confusion matrix (for
%	classification problems) or a plot of output versus target (for
%	regression problems).
%
%	DEMTRAIN(ACTION) is how the GUI callbacks re-enter this function.
%	ACTION is one of 'initialise' (the default), 'slider_moved',
%	'iterations_moved', 'get_ip_file', 'start', 'close', 'classify' or
%	'predict'.  Any other value is ignored.
%
%	See also
%	CONFMAT, DATREAD, MLP, NETOPT, SCG
%

%	Copyright (c) Ian T Nabney (1996-2001)

% If run without parameters, initialise gui.
if nargin < 1
  action = 'initialise';
end

% Each action is handled by a local function below; unknown actions are
% silently ignored.
switch action
  case 'initialise'
    local_initialise;
  case 'slider_moved'
    local_slider_moved;
  case 'iterations_moved'
    local_iterations_moved;
  case 'get_ip_file'
    local_get_ip_file;
  case 'start'
    local_start;
  case 'close'
    local_close;
  case 'classify'
    local_classify;
  case 'predict'
    local_predict;
end


%--------------------------------------------------------------------------
function local_initialise
% Build the demo figure and its controls, and reset the globals that track
% the GUI figure and any results figures.
%
% The figure's UserData holds the handles of the important controls:
%   1 hidden-units slider    2 hidden-units label
%   3 data file name text    4 data file path text
%   5 activation popup       6 iterations slider
%   7 iterations label       8 Start Training button

global DEMTRAIN_FIG            % handle of the GUI figure
global DEMTRAIN_RES_FIGS       % handles of results figures opened so far
global NUM_DEMTRAIN_RES_FIGS   % number of entries used in DEMTRAIN_RES_FIGS

bgColour = [0.7529 0.7529 0.7529];
panelColour = [0.5 0.5 0.5];

% Create FIGURE
fig = figure( ...
  'Name', 'Netlab Demo', ...
  'NumberTitle', 'off', ...
  'Menubar', 'none', ...
  'Color', bgColour, ...
  'Visible', 'on');
% Initialise the globals
DEMTRAIN_FIG = fig;
DEMTRAIN_RES_FIGS = 0;
NUM_DEMTRAIN_RES_FIGS = 0;

% Create GROUP for buttons
uicontrol(fig, ...
  'Style', 'frame', ...
  'Units', 'normalized', ...
  'Position', [0.03 0.08 0.94 0.22], ...
  'BackgroundColor', panelColour);

% Create MAIN axis (the training error is plotted here)
axes( ...
  'Parent', fig, ...
  'Units', 'normalized', ...
  'Position', [0.10 0.5 0.80 0.40], ...
  'XColor', [0 0 0], ...
  'YColor', [0 0 0], ...
  'Visible', 'on');

% Create static text for FILENAME and PATH
hFilename = uicontrol(fig, ...
  'Style', 'text', ...
  'Units', 'normalized', ...
  'BackgroundColor', bgColour, ...
  'Position', [0.05 0.32 0.90 0.05], ...
  'HorizontalAlignment', 'center', ...
  'String', 'Please load data file.', ...
  'Visible', 'on');
hPath = uicontrol(fig, ...
  'Style', 'text', ...
  'Units', 'normalized', ...
  'BackgroundColor', bgColour, ...
  'Position', [0.05 0.37 0.90 0.05], ...
  'HorizontalAlignment', 'center', ...
  'String', '', ...
  'Visible', 'on');

% Create NO OF HIDDEN UNITS slider and text
hSliderText = uicontrol(fig, ...
  'Style', 'text', ...
  'BackgroundColor', panelColour, ...
  'Units', 'normalized', ...
  'Position', [0.27 0.12 0.17 0.04], ...
  'HorizontalAlignment', 'right', ...
  'String', 'Hidden Units: 5');
hSlider = uicontrol(fig, ...
  'Style', 'slider', ...
  'Units', 'normalized', ...
  'Position', [0.45 0.12 0.26 0.04], ...
  'String', 'Slider', ...
  'Min', 1, 'Max', 25, ...
  'Value', 5, ...
  'Callback', 'demtrain slider_moved');

% Create ITERATIONS slider and text
hIterationsText = uicontrol(fig, ...
  'Style', 'text', ...
  'BackgroundColor', panelColour, ...
  'Units', 'normalized', ...
  'Position', [0.27 0.21 0.17 0.04], ...
  'HorizontalAlignment', 'right', ...
  'String', 'Iterations: 50');
hIterations = uicontrol(fig, ...
  'Style', 'slider', ...
  'Units', 'normalized', ...
  'Position', [0.45 0.21 0.26 0.04], ...
  'String', 'Slider', ...
  'Min', 10, 'Max', 500, ...
  'Value', 50, ...
  'Callback', 'demtrain iterations_moved');

% Create ACTIVATION FUNCTION popup and text
uicontrol(fig, ...
  'Style', 'text', ...
  'BackgroundColor', panelColour, ...
  'Units', 'normalized', ...
  'Position', [0.05 0.20 0.20 0.04], ...
  'HorizontalAlignment', 'center', ...
  'String', 'Activation Function:');
hPopup = uicontrol(fig, ...
  'Style', 'popupmenu', ...
  'Units', 'normalized', ...
  'Position' , [0.05 0.10 0.20 0.08], ...
  'String', 'Linear|Logistic|Softmax', ...
  'Callback', '');

% Create MENU (parented explicitly so it always lands on this figure)
hMenu1 = uimenu(fig, 'Label', 'Load Data file...', 'Callback', '');
uimenu(hMenu1, 'Label', 'Select training data file', ...
  'Callback', 'demtrain get_ip_file');
hMenu2 = uimenu(fig, 'Label', 'Show Results...', 'Callback', '');
uimenu(hMenu2, 'Label', 'Show classification results', ...
  'Callback', 'demtrain classify');
uimenu(hMenu2, 'Label', 'Show regression results', ...
  'Callback', 'demtrain predict');

% Create START button (enabled once a data file has been chosen)
hStart = uicontrol(fig, ...
  'Units', 'normalized', ...
  'Position' , [0.75 0.2 0.20 0.08], ...
  'String', 'Start Training', ...
  'Enable', 'off',...
  'Callback', 'demtrain start');

% Create CLOSE button
uicontrol(fig, ...
  'Units', 'normalized', ...
  'Position' , [0.75 0.1 0.20 0.08], ...
  'String', 'Close', ...
  'Callback', 'demtrain close');

% Save handles of important UI objects (layout documented above)
hndlList = [hSlider hSliderText hFilename hPath hPopup ...
    hIterations hIterationsText hStart];
set(fig, 'UserData', hndlList);
% Hide window from command line
set(fig, 'HandleVisibility', 'callback');


%--------------------------------------------------------------------------
function local_slider_moved
% Snap the hidden-units slider to a whole number and update its label.

hndlList = get(local_gui_fig, 'UserData');
hSlider = hndlList(1);
hSliderText = hndlList(2);

val = get(hSlider, 'Value');
% Force up and down arrows to work: an arrow click moves the slider by a
% fraction of a unit, so round away from the previous whole value.
if rem(val, 1) < 0.5
  val = ceil(val);
else
  val = floor(val);
end
set(hSlider, 'Value', val);
set(hSliderText, 'String', ['Hidden Units: ' int2str(val)]);


%--------------------------------------------------------------------------
function local_iterations_moved
% Snap the iterations slider to a whole number and update its label, so
% the value used for training matches the value displayed.

hndlList = get(local_gui_fig, 'UserData');
hIterations = hndlList(6);
hIterationsText = hndlList(7);

val = round(get(hIterations, 'Value'));
set(hIterations, 'Value', val);
set(hIterationsText, 'String', ['Iterations: ' int2str(val)]);


%--------------------------------------------------------------------------
function local_get_ip_file
% Ask the user for a data file, show its name and folder in the GUI, and
% enable training only when a file was actually chosen.

hndlList = get(local_gui_fig, 'UserData');

[file, dirname] = uigetfile('*.dat', 'Get Data File', 50, 50);

% UIGETFILE returns 0 when the dialog is cancelled.
if isequal(file, 0) || isempty(file)
  set(hndlList(3), 'String', 'No data file loaded.');
  set(hndlList(4), 'String', '');
  set(hndlList(8), 'Enable', 'off');
else
  set(hndlList(3), 'String', file);
  set(hndlList(4), 'String', dirname);
  set(hndlList(8), 'Enable', 'on');
end


%--------------------------------------------------------------------------
function local_start
% Train an MLP on the selected data file with scaled conjugate gradients,
% plotting the error as training proceeds, then save the trained network
% to mlptrain.net.

fig = local_gui_fig;
hndlList = get(fig, 'UserData');

nhidden = round(get(hndlList(1), 'Value'));      % No of hidden units
iterations = round(get(hndlList(6), 'Value'));   % Total training cycles
datafile = local_data_file(hndlList);            % Data file path

% Activation function selected in the popup
switch get(hndlList(5), 'Value')
  case 1
    act_fn = 'linear';
  case 2
    act_fn = 'logistic';
  otherwise
    act_fn = 'softmax';
end

% Check data file exists
if ~local_file_exists(datafile)
  errordlg('Training data file has not been selected.', 'Error');
  return;
end

% Load data file
[x, t] = datread(datafile);

% Call MLPTRAIN repeatedly, while drawing the training graph.
figure(fig);
hold on;
title('Training - please wait.');

% Create net and find initial error
net = mlp(size(x, 2), nhidden, size(t, 2), act_fn);
% Initialise network with inverse variance of 10
net = mlpinit(net, 10);
err = mlperr(net, x, t);
% Work out reporting step: should be sufficiently big to let training
% algorithm have a chance
step = max(ceil(iterations / 50), 5);

% Refresh and rescale axis.  The limits are called ymin/ymax rather than
% min/max: a name used both as a function (MAX, above) and as a variable
% in one function is rejected by MATLAB's parser.
cla;
ymax = err;
ymin = ymax/10;
set(gca, 'YScale', 'log');
ylabel('log Error');
xlabel('No. iterations');
axis([0 iterations ymin ymax+1]);

% Plot circle to show the initial error.  (The old EraseMode='none'
% option was removed from MATLAB graphics in R2014b; drawnow suffices.)
plot(0, err, 'ro');
hold on
drawnow;   % Force redraw

ndone = 0;      % training cycles completed so far
errold = err;
while ndone < iterations
  % The last block may be shorter, so exactly ITERATIONS cycles are run
  % and each point is plotted at the true cumulative cycle count.
  ncycles = min(step, iterations - ndone);
  [net, err] = mlptrain(net, x, t, ncycles);
  % Plot line from last point to new point.
  line([ndone ndone+ncycles], [errold err], 'Color', 'r');
  ndone = ndone + ncycles;
  errold = err;

  % If new point off scale, redraw axes.
  if err > ymax
    ymax = err;
    axis([0 iterations ymin ymax+1]);
  end
  if err < ymin
    ymin = err/10;
    axis([0 iterations ymin ymax+1]);
  end
  % Plot circle to show error of last block of cycles
  plot(ndone, err, 'ro');
  drawnow;   % Force redraw
end

% Force MAT format so the '-mat' load in the results views can read it
% (Octave's default save format is text).
save('mlptrain.net', 'net', '-mat');
zoom on;

title(['Training complete. Final error=', num2str(err)]);


%--------------------------------------------------------------------------
function local_close
% Close the GUI figure and every results figure this demo has opened,
% skipping any the user has already closed.

global DEMTRAIN_FIG
global DEMTRAIN_RES_FIGS
global NUM_DEMTRAIN_RES_FIGS

if ~isempty(DEMTRAIN_FIG) && ishandle(DEMTRAIN_FIG)
  close(DEMTRAIN_FIG);
end
for n = 1:NUM_DEMTRAIN_RES_FIGS
  if ishandle(DEMTRAIN_RES_FIGS(n))
    close(DEMTRAIN_RES_FIGS(n));
  end
end


%--------------------------------------------------------------------------
function local_classify
% Show a confusion matrix of the saved network's outputs on the selected
% data set, recording the figure so 'close' can remove it.

global DEMTRAIN_RES_FIGS
global NUM_DEMTRAIN_RES_FIGS

[x, t, net, ok] = local_load_results;
if ~ok
  return;
end
y = mlpfwd(net, x);

% Save results figure so that it can be closed later
NUM_DEMTRAIN_RES_FIGS = NUM_DEMTRAIN_RES_FIGS + 1;
DEMTRAIN_RES_FIGS(NUM_DEMTRAIN_RES_FIGS) = conffig(y, t);


%--------------------------------------------------------------------------
function local_predict
% For each network output, plot outputs against targets (with the y = x
% reference line) in its own figure, recording each figure for 'close'.

global DEMTRAIN_RES_FIGS
global NUM_DEMTRAIN_RES_FIGS

[x, t, net, ok] = local_load_results;
if ~ok
  return;
end
y = mlpfwd(net, x);

for k = 1:size(y, 2)
  % Save results figure so that it can be closed later
  NUM_DEMTRAIN_RES_FIGS = NUM_DEMTRAIN_RES_FIGS + 1;
  DEMTRAIN_RES_FIGS(NUM_DEMTRAIN_RES_FIGS) = figure;
  hold on;
  title(['Output no ' num2str(k)]);
  plot([0 1], [0 1], 'r:');
  plot(y(:, k), t(:, k), 'o');
  hold off;
end


%--------------------------------------------------------------------------
function [x, t, net, ok] = local_load_results
% Load the selected data set and the network saved by training.  Shows an
% error dialog and returns OK = false if either is unavailable.

x = [];
t = [];
net = [];
ok = false;

if ~local_file_exists('mlptrain.net')
  errordlg('You have not yet trained the network.', 'Error');
  return;
end

datafile = local_data_file(get(local_gui_fig, 'UserData'));
if ~local_file_exists(datafile)
  errordlg('Training data file has not been selected.', 'Error');
  return;
end

[x, t] = datread(datafile);
% Load into a struct rather than creating NET implicitly in the workspace.
saved = load('mlptrain.net', '-mat', 'net');
net = saved.net;
ok = true;


%--------------------------------------------------------------------------
function fig = local_gui_fig
% Figure whose control triggered the running callback; falls back to the
% current figure when called outside a callback.

fig = gcbf;
if isempty(fig)
  fig = gcf;
end


%--------------------------------------------------------------------------
function fname = local_data_file(hndlList)
% Full path of the data file shown in the GUI (path text + file name).
% FULLFILE copes with the trailing separator that UIGETFILE returns.

fname = fullfile(get(hndlList(4), 'String'), get(hndlList(3), 'String'));


%--------------------------------------------------------------------------
function ok = local_file_exists(fname)
% True when FNAME can be opened for reading.  The file is closed again
% immediately so repeated checks do not leak file identifiers.

ok = false;
if isempty(fname)
  return;
end
fid = fopen(fname, 'r');
if fid ~= -1
  fclose(fid);
  ok = true;
end