comparison toolboxes/FullBNT-1.0.7/netlab3.3/demtrain.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function demtrain(action)
%DEMTRAIN Demonstrate training of MLP network.
%
%  Description
%  DEMTRAIN brings up a simple GUI to show the training of an MLP
%  network on classification and regression problems. The user should
%  load in a dataset (which should be in Netlab format: see DATREAD),
%  select the output activation function, the number of cycles and
%  hidden units and then train the network. The scaled conjugate
%  gradient algorithm is used. A graph shows the evolution of the error:
%  the value is shown every MAX(CEIL(ITERATIONS / 50), 5) cycles.
%
%  Once the network is trained, it is saved to the file MLPTRAIN.NET.
%  The results can then be viewed as a confusion matrix (for
%  classification problems) or a plot of output versus target (for
%  regression problems).
%
%  DEMTRAIN(ACTION) is also used internally as the GUI callback
%  dispatcher. ACTION is one of 'initialise' (the default), 'slider_moved',
%  'iterations_moved', 'get_ip_file', 'start', 'close', 'classify' or
%  'predict'.
%
%  See also
%  CONFMAT, DATREAD, MLP, NETOPT, SCG
%

%  Copyright (c) Ian T Nabney (1996-2001)

% If run without parameters, initialise gui.
if nargin < 1
  action = 'initialise';
end

% Global variable to reference GUI figure
global DEMTRAIN_FIG
% Global array to reference sub-figures for results plots
global DEMTRAIN_RES_FIGS
global NUM_DEMTRAIN_RES_FIGS

% The figure's UserData holds the handle list built in 'initialise':
%   1 hidden-units slider   2 hidden-units label   3 file name text
%   4 path text             5 activation popup     6 iterations slider
%   7 iterations label      8 Start button

if strcmp(action, 'initialise')

  % Create FIGURE
  fig = figure( ...
    'Name', 'Netlab Demo', ...
    'NumberTitle', 'off', ...
    'Menubar', 'none', ...
    'Color', [0.7529 0.7529 0.7529], ...
    'Visible', 'on');
  % Initialise the globals
  DEMTRAIN_FIG = fig;
  DEMTRAIN_RES_FIGS = 0;
  NUM_DEMTRAIN_RES_FIGS = 0;

  % Create GROUP for buttons
  uicontrol(fig, ...
    'Style', 'frame', ...
    'Units', 'normalized', ...
    'Position', [0.03 0.08 0.94 0.22], ...
    'BackgroundColor', [0.5 0.5 0.5]);

  % Create MAIN axis
  axes( ...
    'Units', 'normalized', ...
    'Position', [0.10 0.5 0.80 0.40], ...
    'XColor', [0 0 0], ...
    'YColor', [0 0 0], ...
    'Visible', 'on');

  % Create static text for FILENAME and PATH
  hFilename = uicontrol(fig, ...
    'Style', 'text', ...
    'Units', 'normalized', ...
    'BackgroundColor', [0.7529 0.7529 0.7529], ...
    'Position', [0.05 0.32 0.90 0.05], ...
    'HorizontalAlignment', 'center', ...
    'String', 'Please load data file.', ...
    'Visible', 'on');
  hPath = uicontrol(fig, ...
    'Style', 'text', ...
    'Units', 'normalized', ...
    'BackgroundColor', [0.7529 0.7529 0.7529], ...
    'Position', [0.05 0.37 0.90 0.05], ...
    'HorizontalAlignment', 'center', ...
    'String', '', ...
    'Visible', 'on');

  % Create NO OF HIDDEN UNITS slider and text
  hSliderText = uicontrol(fig, ...
    'Style', 'text', ...
    'BackgroundColor', [0.5 0.5 0.5], ...
    'Units', 'normalized', ...
    'Position', [0.27 0.12 0.17 0.04], ...
    'HorizontalAlignment', 'right', ...
    'String', 'Hidden Units: 5');
  hSlider = uicontrol(fig, ...
    'Style', 'slider', ...
    'Units', 'normalized', ...
    'Position', [0.45 0.12 0.26 0.04], ...
    'String', 'Slider', ...
    'Min', 1, 'Max', 25, ...
    'Value', 5, ...
    'Callback', 'demtrain slider_moved');

  % Create ITERATIONS slider and text
  hIterationsText = uicontrol(fig, ...
    'Style', 'text', ...
    'BackgroundColor', [0.5 0.5 0.5], ...
    'Units', 'normalized', ...
    'Position', [0.27 0.21 0.17 0.04], ...
    'HorizontalAlignment', 'right', ...
    'String', 'Iterations: 50');
  hIterations = uicontrol(fig, ...
    'Style', 'slider', ...
    'Units', 'normalized', ...
    'Position', [0.45 0.21 0.26 0.04], ...
    'String', 'Slider', ...
    'Min', 10, 'Max', 500, ...
    'Value', 50, ...
    'Callback', 'demtrain iterations_moved');

  % Create ACTIVATION FUNCTION popup and text
  uicontrol(fig, ...
    'Style', 'text', ...
    'BackgroundColor', [0.5 0.5 0.5], ...
    'Units', 'normalized', ...
    'Position', [0.05 0.20 0.20 0.04], ...
    'HorizontalAlignment', 'center', ...
    'String', 'Activation Function:');
  hPopup = uicontrol(fig, ...
    'Style', 'popup', ...
    'Units', 'normalized', ...
    'Position' , [0.05 0.10 0.20 0.08], ...
    'String', 'Linear|Logistic|Softmax', ...
    'Callback', '');

  % Create MENU
  hMenu1 = uimenu('Label', 'Load Data file...', 'Callback', '');
  uimenu(hMenu1, 'Label', 'Select training data file', ...
    'Callback', 'demtrain get_ip_file');
  hMenu2 = uimenu('Label', 'Show Results...', 'Callback', '');
  uimenu(hMenu2, 'Label', 'Show classification results', ...
    'Callback', 'demtrain classify');
  uimenu(hMenu2, 'Label', 'Show regression results', ...
    'Callback', 'demtrain predict');

  % Create START button (enabled once a data file has been chosen)
  hStart = uicontrol(fig, ...
    'Units', 'normalized', ...
    'Position' , [0.75 0.2 0.20 0.08], ...
    'String', 'Start Training', ...
    'Enable', 'off',...
    'Callback', 'demtrain start');

  % Create CLOSE button
  uicontrol(fig, ...
    'Units', 'normalized', ...
    'Position' , [0.75 0.1 0.20 0.08], ...
    'String', 'Close', ...
    'Callback', 'demtrain close');

  % Save handles of important UI objects (order documented above)
  hndlList = [hSlider hSliderText hFilename hPath hPopup ...
    hIterations hIterationsText hStart];
  set(fig, 'UserData', hndlList);
  % Hide window from command line
  set(fig, 'HandleVisibility', 'callback');

elseif strcmp(action, 'slider_moved')

  % Hidden-units slider has been moved.
  hndlList = get(gcf, 'UserData');
  hSlider = hndlList(1);
  hSliderText = hndlList(2);

  val = get(hSlider, 'Value');
  % Deliberately "inverted" rounding: an arrow click moves the value by a
  % fraction of a unit, so rounding away from the old integer makes the
  % up and down arrows step by one.
  if rem(val, 1) < 0.5
    val = ceil(val);
  else
    val = floor(val);
  end
  set(hSlider, 'Value', val);
  set(hSliderText, 'String', ['Hidden Units: ' int2str(val)]);

elseif strcmp(action, 'iterations_moved')

  % Iterations slider has been moved.
  hndlList = get(gcf, 'UserData');
  hIterations = hndlList(6);
  hIterationsText = hndlList(7);

  % Snap to a whole number so the label and the value used for training
  % agree.
  val = round(get(hIterations, 'Value'));
  set(hIterations, 'Value', val);
  set(hIterationsText, 'String', ['Iterations: ' int2str(val)]);

elseif strcmp(action, 'get_ip_file')

  % Get data file menu item selected.
  hndlList = get(gcf, 'UserData');

  [file, datapath] = uigetfile('*.dat', 'Get Data File', 50, 50);

  % uigetfile returns 0 (not a string) when the dialog is cancelled.
  if isequal(file, 0) || isempty(file)
    set(hndlList(3), 'String', 'No data file loaded.');
    set(hndlList(4), 'String', '');
    % Nothing to train on, so keep the training button disabled.
    set(hndlList(8), 'Enable', 'off');
  else
    set(hndlList(3), 'String', file);
    set(hndlList(4), 'String', datapath);
    % Enable training button
    set(hndlList(8), 'Enable', 'on');
  end

elseif strcmp(action, 'start')

  % Start training

  % Get values from UI objects
  hndlList = get(gcf, 'UserData');
  nhidden = round(get(hndlList(1), 'Value'));      % No of hidden units
  iterations = round(get(hndlList(6), 'Value'));   % Total training cycles
  datafile = fullfile(get(hndlList(4), 'String'), ...
    get(hndlList(3), 'String'));                   % Data path + file name

  switch get(hndlList(5), 'Value')                 % Activation function
    case 1
      act_fn = 'linear';
    case 2
      act_fn = 'logistic';
    otherwise
      act_fn = 'softmax';
  end

  % Check data file exists
  if ~local_file_exists(datafile)
    errordlg('Training data file has not been selected.', 'Error');
  else
    % Load data file
    [x, t] = datread(datafile);

    % Call MLPTRAIN function repeatedly, while drawing training graph.
    figure(DEMTRAIN_FIG);
    hold on;

    title('Training - please wait.');

    % Create net and find initial error
    net = mlp(size(x, 2), nhidden, size(t, 2), act_fn);
    % Initialise network with inverse variance of 10
    net = mlpinit(net, 10);
    err = mlperr(net, x, t);
    % Work out reporting step: should be sufficiently big to let training
    % algorithm have a chance
    step = max(ceil(iterations / 50), 5);

    % Refresh and rescale axis. (ymax/ymin rather than max/min so that the
    % built-in MAX used above is not shadowed by a variable.)
    cla;
    ymax = err;
    ymin = ymax/10;
    set(gca, 'YScale', 'log');
    ylabel('log Error');
    xlabel('No. iterations');
    axis([0 iterations ymin ymax+1]);
    iold = 0;
    errold = err;
    % Plot circle to show the initial error. ('EraseMode' was removed from
    % MATLAB graphics in R2014b, so it is no longer passed here.)
    plot(0, err, 'ro');
    drawnow; % Force redraw

    % Train in chunks of STEP cycles (the last chunk may be shorter) so
    % that exactly ITERATIONS cycles are run and each point is plotted at
    % the true cumulative cycle count.
    done = 0;
    while done < iterations
      nits = min(step, iterations - done);
      [net, err] = mlptrain(net, x, t, nits);
      done = done + nits;

      % Plot line from last point to new point.
      line([iold done], [errold err], 'Color', 'r');
      iold = done;
      errold = err;

      % If new point off scale, redraw axes.
      if err > ymax
        ymax = err;
        axis([0 iterations ymin ymax+1]);
      end
      if err < ymin
        ymin = err/10;
        axis([0 iterations ymin ymax+1]);
      end
      % Plot circle to show error of last iteration
      plot(done, err, 'ro');
      drawnow; % Force redraw
    end
    save mlptrain.net net
    zoom on;

    title(['Training complete. Final error=', num2str(err)]);
  end

elseif strcmp(action, 'close')

  % Close all the figures we have created
  if ishandle(DEMTRAIN_FIG)
    close(DEMTRAIN_FIG);
  end
  for n = 1:NUM_DEMTRAIN_RES_FIGS
    if ishandle(DEMTRAIN_RES_FIGS(n))
      close(DEMTRAIN_RES_FIGS(n));
    end
  end

elseif strcmp(action, 'classify')

  [ok, x, t, net] = local_load_results(get(gcf, 'UserData'));
  if ok
    y = mlpfwd(net, x);

    % Save results figure so that it can be closed later
    NUM_DEMTRAIN_RES_FIGS = NUM_DEMTRAIN_RES_FIGS + 1;
    DEMTRAIN_RES_FIGS(NUM_DEMTRAIN_RES_FIGS) = conffig(y, t);
  end

elseif strcmp(action, 'predict')

  [ok, x, t, net] = local_load_results(get(gcf, 'UserData'));
  if ok
    y = mlpfwd(net, x);

    % One output-versus-target plot per network output.
    for k = 1:size(y, 2)
      % Save results figure so that it can be closed later
      NUM_DEMTRAIN_RES_FIGS = NUM_DEMTRAIN_RES_FIGS + 1;
      DEMTRAIN_RES_FIGS(NUM_DEMTRAIN_RES_FIGS) = figure;
      hold on;
      title(['Output no ' num2str(k)]);
      plot([0 1], [0 1], 'r:');
      plot(y(:, k), t(:, k), 'o');
      hold off;
    end
  end

end


function found = local_file_exists(fname)
%LOCAL_FILE_EXISTS True if FNAME can be opened for reading.
%  The handle is closed immediately so that repeated checks do not leak
%  file descriptors.
fid = fopen(fname, 'r');
found = (fid ~= -1);
if found
  fclose(fid);
end


function [ok, x, t, net] = local_load_results(hndlList)
%LOCAL_LOAD_RESULTS Load the selected data file and the saved network.
%  HNDLLIST is the handle list stored in the GUI figure's UserData
%  (element 3 = file name text, element 4 = path text). OK is false, and an
%  error dialog is shown, when either the trained network (MLPTRAIN.NET)
%  or the data file is missing.
ok = false;
x = [];
t = [];
net = [];
if ~local_file_exists('mlptrain.net')
  errordlg('You have not yet trained the network.', 'Error');
  return;
end
datafile = fullfile(get(hndlList(4), 'String'), get(hndlList(3), 'String'));
if ~local_file_exists(datafile)
  errordlg('Training data file has not been selected.', 'Error');
  return;
end
[x, t] = datread(datafile);
s = load('mlptrain.net', '-mat', 'net');
net = s.net;
ok = true;