Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/netlab3.3/demtrain.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function demtrain(action)
%DEMTRAIN Demonstrate training of MLP network.
%
%  Description
%  DEMTRAIN brings up a simple GUI to show the training of an MLP
%  network on classification and regression problems. The user should
%  load in a dataset (which should be in Netlab format: see DATREAD),
%  select the output activation function, the number of cycles and
%  hidden units and then train the network. The scaled conjugate
%  gradient algorithm is used. A graph shows the evolution of the error:
%  the value is plotted every MAX(CEIL(ITERATIONS / 50), 5) cycles.
%
%  Once the network is trained, it is saved to the file MLPTRAIN.NET.
%  The results can then be viewed as a confusion matrix (for
%  classification problems) or a plot of output versus target (for
%  regression problems).
%
%  DEMTRAIN(ACTION) is also used as the callback dispatcher for the GUI
%  controls; ACTION is one of 'initialise' (the default), 'slider_moved',
%  'iterations_moved', 'get_ip_file', 'start', 'close', 'classify' or
%  'predict'.
%
%  See also
%  CONFMAT, DATREAD, MLP, NETOPT, SCG
%

%  Copyright (c) Ian T Nabney (1996-2001)

% If run without parameters, initialise gui.
if nargin < 1
  action = 'initialise';
end

% Global variable to reference GUI figure
global DEMTRAIN_FIG
% Global array to reference sub-figures for results plots
global DEMTRAIN_RES_FIGS
global NUM_DEMTRAIN_RES_FIGS

if strcmp(action, 'initialise')

  % Create FIGURE
  fig = figure( ...
    'Name', 'Netlab Demo', ...
    'NumberTitle', 'off', ...
    'Menubar', 'none', ...
    'Color', [0.7529 0.7529 0.7529], ...
    'Visible', 'on');
  % Initialise the globals
  DEMTRAIN_FIG = fig;
  DEMTRAIN_RES_FIGS = 0;
  NUM_DEMTRAIN_RES_FIGS = 0;

  % Create GROUP for buttons
  uicontrol(fig, ...
    'Style', 'frame', ...
    'Units', 'normalized', ...
    'Position', [0.03 0.08 0.94 0.22], ...
    'BackgroundColor', [0.5 0.5 0.5]);

  % Create MAIN axis
  hMain = axes( ...
    'Units', 'normalized', ...
    'Position', [0.10 0.5 0.80 0.40], ...
    'XColor', [0 0 0], ...
    'YColor', [0 0 0], ...
    'Visible', 'on');

  % Create static text for FILENAME and PATH
  hFilename = uicontrol(fig, ...
    'Style', 'text', ...
    'Units', 'normalized', ...
    'BackgroundColor', [0.7529 0.7529 0.7529], ...
    'Position', [0.05 0.32 0.90 0.05], ...
    'HorizontalAlignment', 'center', ...
    'String', 'Please load data file.', ...
    'Visible', 'on');
  hPath = uicontrol(fig, ...
    'Style', 'text', ...
    'Units', 'normalized', ...
    'BackgroundColor', [0.7529 0.7529 0.7529], ...
    'Position', [0.05 0.37 0.90 0.05], ...
    'HorizontalAlignment', 'center', ...
    'String', '', ...
    'Visible', 'on');

  % Create NO OF HIDDEN UNITS slider and text
  hSliderText = uicontrol(fig, ...
    'Style', 'text', ...
    'BackgroundColor', [0.5 0.5 0.5], ...
    'Units', 'normalized', ...
    'Position', [0.27 0.12 0.17 0.04], ...
    'HorizontalAlignment', 'right', ...
    'String', 'Hidden Units: 5');
  hSlider = uicontrol(fig, ...
    'Style', 'slider', ...
    'Units', 'normalized', ...
    'Position', [0.45 0.12 0.26 0.04], ...
    'String', 'Slider', ...
    'Min', 1, 'Max', 25, ...
    'Value', 5, ...
    'Callback', 'demtrain slider_moved');

  % Create ITERATIONS slider and text
  hIterationsText = uicontrol(fig, ...
    'Style', 'text', ...
    'BackgroundColor', [0.5 0.5 0.5], ...
    'Units', 'normalized', ...
    'Position', [0.27 0.21 0.17 0.04], ...
    'HorizontalAlignment', 'right', ...
    'String', 'Iterations: 50');
  hIterations = uicontrol(fig, ...
    'Style', 'slider', ...
    'Units', 'normalized', ...
    'Position', [0.45 0.21 0.26 0.04], ...
    'String', 'Slider', ...
    'Min', 10, 'Max', 500, ...
    'Value', 50, ...
    'Callback', 'demtrain iterations_moved');

  % Create ACTIVATION FUNCTION popup and text
  uicontrol(fig, ...
    'Style', 'text', ...
    'BackgroundColor', [0.5 0.5 0.5], ...
    'Units', 'normalized', ...
    'Position', [0.05 0.20 0.20 0.04], ...
    'HorizontalAlignment', 'center', ...
    'String', 'Activation Function:');
  hPopup = uicontrol(fig, ...
    'Style', 'popup', ...
    'Units', 'normalized', ...
    'Position' , [0.05 0.10 0.20 0.08], ...
    'String', 'Linear|Logistic|Softmax', ...
    'Callback', '');

  % Create MENU
  hMenu1 = uimenu('Label', 'Load Data file...', 'Callback', '');
  uimenu(hMenu1, 'Label', 'Select training data file', ...
    'Callback', 'demtrain get_ip_file');
  hMenu2 = uimenu('Label', 'Show Results...', 'Callback', '');
  uimenu(hMenu2, 'Label', 'Show classification results', ...
    'Callback', 'demtrain classify');
  uimenu(hMenu2, 'Label', 'Show regression results', ...
    'Callback', 'demtrain predict');

  % Create START button (disabled until a data file has been chosen)
  hStart = uicontrol(fig, ...
    'Units', 'normalized', ...
    'Position' , [0.75 0.2 0.20 0.08], ...
    'String', 'Start Training', ...
    'Enable', 'off',...
    'Callback', 'demtrain start');

  % Create CLOSE button
  uicontrol(fig, ...
    'Units', 'normalized', ...
    'Position' , [0.75 0.1 0.20 0.08], ...
    'String', 'Close', ...
    'Callback', 'demtrain close');

  % Save handles of important UI objects. The callbacks below index this
  % list by position: 1 hidden-units slider, 2 its label, 3 file name,
  % 4 file path, 5 activation popup, 6 iterations slider, 7 its label,
  % 8 start button.
  hndlList = [hSlider hSliderText hFilename hPath hPopup ...
      hIterations hIterationsText hStart];
  set(fig, 'UserData', hndlList);
  % Hide window from command line
  set(fig, 'HandleVisibility', 'callback');

elseif strcmp(action, 'slider_moved')

  % Hidden-units slider has been moved.
  hndlList = get(gcf, 'UserData');
  hSlider = hndlList(1);
  hSliderText = hndlList(2);

  val = get(hSlider, 'Value');
  % Deliberately round "away" from the previous integer so that the small
  % arrow-button steps still change the value. Force up and down arrows
  % to work!
  if rem(val, 1) < 0.5
    val = ceil(val);
  else
    val = floor(val);
  end
  set(hSlider, 'Value', val);
  set(hSliderText, 'String', ['Hidden Units: ' int2str(val)]);

elseif strcmp(action, 'iterations_moved')

  % Iterations slider has been moved: snap it to a whole number so that
  % the label and the value later used by 'start' agree.
  hndlList = get(gcf, 'UserData');
  hIterations = hndlList(6);
  hIterationsText = hndlList(7);

  val = round(get(hIterations, 'Value'));
  set(hIterations, 'Value', val);
  set(hIterationsText, 'String', ['Iterations: ' int2str(val)]);

elseif strcmp(action, 'get_ip_file')

  % Get data file menu item selected.
  hndlList = get(gcf, 'UserData');

  [file, datapath] = uigetfile('*.dat', 'Get Data File', 50, 50);

  % uigetfile returns the number 0 (not a string) when cancelled.
  if isequal(file, 0) || isempty(file)
    set(hndlList(3), 'String', 'No data file loaded.');
    set(hndlList(4), 'String', '');
    % Nothing to train on, so keep the training button disabled.
    set(hndlList(8), 'Enable', 'off');
  else
    set(hndlList(3), 'String', file);
    set(hndlList(4), 'String', datapath);
    % Enable training button
    set(hndlList(8), 'Enable', 'on');
  end

elseif strcmp(action, 'start')

  % Start training

  % Get values from UI objects
  hndlList = get(gcf, 'UserData');
  nhidden = get(hndlList(1), 'Value');            % No of hidden units
  iterations = round(get(hndlList(6), 'Value'));  % No of training cycles
  filename = get(hndlList(3), 'String');          % Data file name
  datapath = get(hndlList(4), 'String');          % Data file path
  switch get(hndlList(5), 'Value')                % Activation function
    case 1
      act_fn = 'linear';
    case 2
      act_fn = 'logistic';
    otherwise
      act_fn = 'softmax';
  end
  % Same path is used for the existence check and for loading.
  datafile = fullfile(datapath, filename);

  % Check data file exists
  if ~local_can_open(datafile)
    errordlg('Training data file has not been selected.', 'Error');
  else
    % Load data file
    [x, t] = datread(datafile);

    % Call MLPTRAIN function repeatedly, while drawing training graph.
    figure(DEMTRAIN_FIG);
    % Hold must be on before the first plot so that plot() does not reset
    % the log Y scale set below.
    hold on;

    title('Training - please wait.');

    % Create net and find initial error
    net = mlp(size(x, 2), nhidden, size(t, 2), act_fn);
    % Initialise network with inverse variance of 10
    net = mlpinit(net, 10);
    err = mlperr(net, x, t);
    % Work out reporting step: should be sufficiently big to let training
    % algorithm have a chance
    step = max(ceil(iterations / 50), 5);

    % Refresh and rescale axis.
    cla;
    ymax = err;
    ymin = ymax/10;
    set(gca, 'YScale', 'log');
    ylabel('log Error');
    xlabel('No. iterations');
    axis([0 iterations ymin ymax+1]);
    iold = 0;
    errold = err;
    % Plot circle to show the initial error. Setting erase mode to none
    % prevents screen flashing during training.
    % NOTE(review): EraseMode is not supported by newer MATLAB graphics
    % (R2014b onwards) -- confirm behaviour on the target release.
    plot(0, err, 'ro', 'EraseMode', 'none');
    hold on
    drawnow; % Force redraw
    % Each MLPTRAIN call runs STEP further cycles, so after the k-th call
    % exactly k*STEP cycles have been done: plot the point at that x value
    % and never train beyond the requested number of iterations.
    for i = step:step:iterations
      [net, err] = mlptrain(net, x, t, step);
      % Plot line from last point to new point.
      line([iold i], [errold err], 'Color', 'r', 'EraseMode', 'none');
      iold = i;
      errold = err;

      % If new point off scale, redraw axes.
      if err > ymax
        ymax = err;
        axis([0 iterations ymin ymax+1]);
      end
      if err < ymin
        ymin = err/10;
        axis([0 iterations ymin ymax+1]);
      end
      % Plot circle to show error of last iteration
      plot(i, err, 'ro', 'EraseMode', 'none');
      drawnow; % Force redraw
    end
    save mlptrain.net net
    zoom on;

    title(['Training complete. Final error=', num2str(err)]);

  end

elseif strcmp(action, 'close')

  % Close all the figures we have created
  close(DEMTRAIN_FIG);
  for n = 1:NUM_DEMTRAIN_RES_FIGS
    if ishandle(DEMTRAIN_RES_FIGS(n))
      close(DEMTRAIN_RES_FIGS(n));
    end
  end

elseif strcmp(action, 'classify')

  [y, t, ok] = local_net_outputs(gcf);
  if ok
    % Save results figure so that it can be closed later
    NUM_DEMTRAIN_RES_FIGS = NUM_DEMTRAIN_RES_FIGS + 1;
    DEMTRAIN_RES_FIGS(NUM_DEMTRAIN_RES_FIGS) = conffig(y, t);
  end

elseif strcmp(action, 'predict')

  [y, t, ok] = local_net_outputs(gcf);
  if ok
    % One output-versus-target plot per network output.
    for i = 1:size(y, 2)
      % Save results figure so that it can be closed later
      NUM_DEMTRAIN_RES_FIGS = NUM_DEMTRAIN_RES_FIGS + 1;
      DEMTRAIN_RES_FIGS(NUM_DEMTRAIN_RES_FIGS) = figure;
      hold on;
      title(['Output no ' num2str(i)]);
      plot([0 1], [0 1], 'r:');
      plot(y(:,i), t(:,i), 'o');
      hold off;
    end
  end

end


function [y, t, ok] = local_net_outputs(fig)
%LOCAL_NET_OUTPUTS Run the saved network on the currently selected data.
%  [Y, T, OK] = LOCAL_NET_OUTPUTS(FIG) loads the network saved in
%  MLPTRAIN.NET and the data file named in the GUI figure FIG, and returns
%  the network outputs Y and the targets T. If either file cannot be
%  opened, an error dialog is shown and OK is false (Y and T are empty).
y = [];
t = [];
ok = false;
if ~local_can_open('mlptrain.net')
  errordlg('You have not yet trained the network.', 'Error');
  return;
end
hndlList = get(fig, 'UserData');
datafile = fullfile(get(hndlList(4), 'String'), get(hndlList(3), 'String'));
if ~local_can_open(datafile)
  errordlg('Training data file has not been selected.', 'Error');
  return;
end
[x, t] = datread(datafile);
% Struct-form load: NET is a real variable here, not one created
% implicitly by the command-form LOAD.
s = load('mlptrain.net', '-mat', 'net');
y = mlpfwd(s.net, x);
ok = true;


function ok = local_can_open(name)
%LOCAL_CAN_OPEN True if the file NAME can be opened for reading.
%  The test handle is closed immediately so that repeated checks do not
%  leak file identifiers.
fid = fopen(name);
ok = (fid ~= -1);
if ok
  fclose(fid);
end