function demtrain(action)
%DEMTRAIN Demonstrate training of MLP network.
%
% Description
% DEMTRAIN brings up a simple GUI to show the training of an MLP
% network on classification and regression problems. The user should
% load in a dataset (which should be in Netlab format: see DATREAD),
% select the output activation function, the number of cycles and
% hidden units and then train the network. The scaled conjugate
% gradient algorithm is used. A graph shows the evolution of the error:
% the value is shown every MAX(CEIL(ITERATIONS / 50), 5) cycles.
%
% Once the network is trained, it is saved to the file MLPTRAIN.NET.
% The results can then be viewed as a confusion matrix (for
% classification problems) or a plot of output versus target (for
% regression problems).
%
% DEMTRAIN(ACTION) is the callback dispatcher used by the GUI itself;
% ACTION is one of 'initialise' (the default), 'slider_moved',
% 'iterations_moved', 'get_ip_file', 'start', 'close', 'classify' or
% 'predict'.
%
% See also
% CONFMAT, DATREAD, MLP, NETOPT, SCG
%

% Copyright (c) Ian T Nabney (1996-2001)

% If run without parameters, initialise gui.
if nargin < 1,
  action = 'initialise';
end;

% Global variable to reference GUI figure
global DEMTRAIN_FIG
% Global array to reference sub-figures for results plots
global DEMTRAIN_RES_FIGS
global NUM_DEMTRAIN_RES_FIGS

% File the trained network is saved to and read back from.
netfile = 'mlptrain.net';

if strcmp(action, 'initialise'),

  % Create FIGURE
  fig = figure( ...
    'Name', 'Netlab Demo', ...
    'NumberTitle', 'off', ...
    'Menubar', 'none', ...
    'Color', [0.7529 0.7529 0.7529], ...
    'Visible', 'on');
  % Initialise the globals
  DEMTRAIN_FIG = fig;
  DEMTRAIN_RES_FIGS = 0;
  NUM_DEMTRAIN_RES_FIGS = 0;

  % Create GROUP for buttons
  uicontrol(fig, ...
    'Style', 'frame', ...
    'Units', 'normalized', ...
    'Position', [0.03 0.08 0.94 0.22], ...
    'BackgroundColor', [0.5 0.5 0.5]);

  % Create MAIN axis
  hMain = axes( ...
    'Units', 'normalized', ...
    'Position', [0.10 0.5 0.80 0.40], ...
    'XColor', [0 0 0], ...
    'YColor', [0 0 0], ...
    'Visible', 'on');

  % Create static text for FILENAME and PATH
  hFilename = uicontrol(fig, ...
    'Style', 'text', ...
    'Units', 'normalized', ...
    'BackgroundColor', [0.7529 0.7529 0.7529], ...
    'Position', [0.05 0.32 0.90 0.05], ...
    'HorizontalAlignment', 'center', ...
    'String', 'Please load data file.', ...
    'Visible', 'on');
  hPath = uicontrol(fig, ...
    'Style', 'text', ...
    'Units', 'normalized', ...
    'BackgroundColor', [0.7529 0.7529 0.7529], ...
    'Position', [0.05 0.37 0.90 0.05], ...
    'HorizontalAlignment', 'center', ...
    'String', '', ...
    'Visible', 'on');

  % Create NO OF HIDDEN UNITS slider and text
  hSliderText = uicontrol(fig, ...
    'Style', 'text', ...
    'BackgroundColor', [0.5 0.5 0.5], ...
    'Units', 'normalized', ...
    'Position', [0.27 0.12 0.17 0.04], ...
    'HorizontalAlignment', 'right', ...
    'String', 'Hidden Units: 5');
  hSlider = uicontrol(fig, ...
    'Style', 'slider', ...
    'Units', 'normalized', ...
    'Position', [0.45 0.12 0.26 0.04], ...
    'String', 'Slider', ...
    'Min', 1, 'Max', 25, ...
    'Value', 5, ...
    'Callback', 'demtrain slider_moved');

  % Create ITERATIONS slider and text
  hIterationsText = uicontrol(fig, ...
    'Style', 'text', ...
    'BackgroundColor', [0.5 0.5 0.5], ...
    'Units', 'normalized', ...
    'Position', [0.27 0.21 0.17 0.04], ...
    'HorizontalAlignment', 'right', ...
    'String', 'Iterations: 50');
  hIterations = uicontrol(fig, ...
    'Style', 'slider', ...
    'Units', 'normalized', ...
    'Position', [0.45 0.21 0.26 0.04], ...
    'String', 'Slider', ...
    'Min', 10, 'Max', 500, ...
    'Value', 50, ...
    'Callback', 'demtrain iterations_moved');

  % Create ACTIVATION FUNCTION popup and text
  uicontrol(fig, ...
    'Style', 'text', ...
    'BackgroundColor', [0.5 0.5 0.5], ...
    'Units', 'normalized', ...
    'Position', [0.05 0.20 0.20 0.04], ...
    'HorizontalAlignment', 'center', ...
    'String', 'Activation Function:');
  hPopup = uicontrol(fig, ...
    'Style', 'popup', ...
    'Units', 'normalized', ...
    'Position' , [0.05 0.10 0.20 0.08], ...
    'String', 'Linear|Logistic|Softmax', ...
    'Callback', '');

  % Create MENU
  hMenu1 = uimenu('Label', 'Load Data file...', 'Callback', '');
  uimenu(hMenu1, 'Label', 'Select training data file', ...
    'Callback', 'demtrain get_ip_file');
  hMenu2 = uimenu('Label', 'Show Results...', 'Callback', '');
  uimenu(hMenu2, 'Label', 'Show classification results', ...
    'Callback', 'demtrain classify');
  uimenu(hMenu2, 'Label', 'Show regression results', ...
    'Callback', 'demtrain predict');

  % Create START button (enabled once a data file has been chosen)
  hStart = uicontrol(fig, ...
    'Units', 'normalized', ...
    'Position' , [0.75 0.2 0.20 0.08], ...
    'String', 'Start Training', ...
    'Enable', 'off',...
    'Callback', 'demtrain start');

  % Create CLOSE button
  uicontrol(fig, ...
    'Units', 'normalized', ...
    'Position' , [0.75 0.1 0.20 0.08], ...
    'String', 'Close', ...
    'Callback', 'demtrain close');

  % Save handles of important UI objects.  The callbacks below index
  % this list by position:
  %   1 hSlider, 2 hSliderText, 3 hFilename, 4 hPath, 5 hPopup,
  %   6 hIterations, 7 hIterationsText, 8 hStart
  hndlList = [hSlider hSliderText hFilename hPath hPopup ...
      hIterations hIterationsText hStart];
  set(fig, 'UserData', hndlList);
  % Hide window from command line
  set(fig, 'HandleVisibility', 'callback');


elseif strcmp(action, 'slider_moved'),

  % Hidden units slider has been moved.

  hndlList = get(gcf, 'UserData');
  hSlider = hndlList(1);
  hSliderText = hndlList(2);

  val = get(hSlider, 'Value');
  if rem(val, 1) < 0.5,	% Force up and down arrows to work!
    val = ceil(val);
  else
    val = floor(val);
  end;
  set(hSlider, 'Value', val);
  set(hSliderText, 'String', ['Hidden Units: ' int2str(val)]);


elseif strcmp(action, 'iterations_moved'),

  % Iterations slider has been moved.

  hndlList = get(gcf, 'UserData');
  hSlider = hndlList(6);
  hSliderText = hndlList(7);

  val = get(hSlider, 'Value');
  set(hSliderText, 'String', ['Iterations: ' int2str(val)]);

elseif strcmp(action, 'get_ip_file'),

  % Get data file button pressed.

  hndlList = get(gcf, 'UserData');

  [file, pathname] = uigetfile('*.dat', 'Get Data File');

  % UIGETFILE returns 0 for the file name when the dialog is cancelled.
  if isequal(file, 0) | isempty(file),
    set(hndlList(3), 'String', 'No data file loaded.');
    set(hndlList(4), 'String', '');
    % Nothing to train on, so keep the training button disabled.
    set(hndlList(8), 'Enable', 'off');
  else
    set(hndlList(3), 'String', file);
    set(hndlList(4), 'String', pathname);
    % Enable training button
    set(hndlList(8), 'Enable', 'on');
  end;

elseif strcmp(action, 'start'),

  % Start training

  % Get handles of and values from UI objects
  hndlList = get(gcf, 'UserData');
  hSlider = hndlList(1);		% No of hidden units
  hIterations = hndlList(6);
  % The iterations slider is never snapped to whole numbers, so round
  % both slider values before using them as counts.
  iterations = round(get(hIterations, 'Value'));
  nhidden = round(get(hSlider, 'Value'));

  hPopup = hndlList(5);		% Activation function
  switch get(hPopup, 'Value')
    case 1
      act_fn = 'linear';
    case 2
      act_fn = 'logistic';
    otherwise
      act_fn = 'softmax';
  end;

  % Check data file exists
  datafile = demtrain_datafile(hndlList);
  if isempty(datafile),
    errordlg('Training data file has not been selected.', 'Error');
  else
    % Load data file
    [x, t] = datread(datafile);

    % Call MLPTRAIN function repeatedly, while drawing training graph.
    figure(DEMTRAIN_FIG);
    hold on;

    title('Training - please wait.');

    % Create net and find initial error
    net = mlp(size(x, 2), nhidden, size(t, 2), act_fn);
    % Initialise network with inverse variance of 10
    net = mlpinit(net, 10);
    err = mlperr(net, x, t);
    % Work out reporting step: should be sufficiently big to let training
    % algorithm have a chance
    step = max(ceil(iterations / 50), 5);

    % Refresh and rescale axis.  ERRMAX/ERRMIN are used rather than
    % MAX/MIN so that the builtins are not shadowed (MAX is called above).
    cla;
    errmax = err;
    errmin = errmax/10;
    set(gca, 'YScale', 'log');
    ylabel('log Error');
    xlabel('No. iterations');
    axis([0 iterations errmin errmax+1]);
    iold = 0;
    errold = err;
    % Plot circle to show error of last iteration.  DRAWNOW forces the
    % redraw; 'EraseMode' is not used as current MATLAB graphics no
    % longer support it.
    plot(0, err, 'ro');
    hold on
    drawnow;
    % After the k-th call, k*step training cycles have been run, so the
    % point is plotted at x = k*step.
    for i = step:step:iterations,
      [net, err] = mlptrain(net, x, t, step);
      % Plot line from last point to new point.
      line([iold i], [errold err], 'Color', 'r');
      iold = i;
      errold = err;

      % If new point off scale, redraw axes.
      if err > errmax,
        errmax = err;
        axis([0 iterations errmin errmax+1]);
      end;
      if err < errmin,
        errmin = err/10;
        axis([0 iterations errmin errmax+1]);
      end;
      % Plot circle to show error of last iteration
      plot(i, err, 'ro');
      drawnow;		% Force redraw
    end;
    % SAVE writes MAT format regardless of the .net extension.
    save(netfile, 'net');
    zoom on;

    title(['Training complete. Final error=', num2str(err)]);

  end;

elseif strcmp(action, 'close'),

  % Close all the figures we have created
  close(DEMTRAIN_FIG);
  for n = 1:NUM_DEMTRAIN_RES_FIGS
    if ishandle(DEMTRAIN_RES_FIGS(n))
      close(DEMTRAIN_RES_FIGS(n));
    end
  end

elseif strcmp(action, 'classify'),

  hndlList = get(gcf, 'UserData');
  [y, t, ok] = demtrain_results(hndlList, netfile);
  if ok,
    % Save results figure so that it can be closed later
    NUM_DEMTRAIN_RES_FIGS = NUM_DEMTRAIN_RES_FIGS + 1;
    DEMTRAIN_RES_FIGS(NUM_DEMTRAIN_RES_FIGS) = conffig(y, t);
  end;

elseif strcmp(action, 'predict'),

  hndlList = get(gcf, 'UserData');
  [y, t, ok] = demtrain_results(hndlList, netfile);
  if ok,
    % One output-versus-target plot per network output.
    for i = 1:size(y, 2),
      % Save results figure so that it can be closed later
      NUM_DEMTRAIN_RES_FIGS = NUM_DEMTRAIN_RES_FIGS + 1;
      DEMTRAIN_RES_FIGS(NUM_DEMTRAIN_RES_FIGS) = figure;
      hold on;
      title(['Output no ' num2str(i)]);
      plot([0 1], [0 1], 'r:');
      plot(y(:,i), t(:,i), 'o');
      hold off;
    end;
  end;

end;


function datafile = demtrain_datafile(hndlList)
%DEMTRAIN_DATAFILE Full name of the selected data file, or '' if none.
%	Reads the file name and directory shown in the GUI (HNDLLIST(3) and
%	HNDLLIST(4)) and returns '' when no file has been selected or the
%	selected file does not exist.
filename = get(hndlList(3), 'String');
pathname = get(hndlList(4), 'String');
datafile = '';
% The path text is empty until a file has been chosen successfully.
if isempty(pathname),
  return;
end;
candidate = fullfile(pathname, filename);
if exist(candidate, 'file') == 2,
  datafile = candidate;
end;


function [y, t, ok] = demtrain_results(hndlList, netfile)
%DEMTRAIN_RESULTS Run the saved network on the selected data file.
%	Loads the network saved in NETFILE, reads the data file selected in
%	the GUI and returns the network outputs Y and the targets T.  OK is 0
%	(and an error dialog has been shown) when there is no trained network
%	or no data file; Y and T are then empty.
y = [];
t = [];
ok = 0;
if exist(netfile, 'file') ~= 2,
  errordlg('You have not yet trained the network.', 'Error');
  return;
end;
datafile = demtrain_datafile(hndlList);
if isempty(datafile),
  errordlg('Training data file has not been selected.', 'Error');
  return;
end;
[x, t] = datread(datafile);
s = load(netfile, '-mat');
y = mlpfwd(s.net, x);
ok = 1;