Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/examples/static/HME/hmemenu.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
% dataset      -> (1 => user data) or (2 => toy example)
% type         -> (1 => regression model) or (2 => classification model)
% num_glevel   -> number of hidden nodes in the net (gating levels)
% num_exp      -> number of experts in the net
% branch_fact  -> dimension of the hidden nodes in the net
% cov_dim      -> root node dimension
% res_dim      -> output node dimension
% nodes_info   -> 4 x (num_glevel+2) matrix that contains all the info about the nodes:
%     nodes_info(1,:) = node type: (0 => gaussian) or (1 => softmax) or (2 => mlp)
%     nodes_info(2,:) = node size: [cov_dim, branch_fact (repeated num_glevel times), res_dim]
%     nodes_info(3,:) = number of hidden units (for mlp nodes)
%     nodes_info(4,:) = optimizer iteration number (for softmax & mlp CPDs), or
%                       covariance type (for the gaussian CPD):
%                       (1 => Full) or (2 => Diagonal) or (3 => Full & Tied) or (4 => Diagonal & Tied)
% fh1 -> figure: data & decision boundaries (classification) or data & prediction (regression);
% fh2 -> confusion matrix; fh3 -> LL trace
% test_data    -> test data matrix
% train_data   -> training data matrix
% ntrain       -> size(train_data,1)
% ntest        -> size(test_data,1)
% cases        -> (cell array) training data formatted for the learning engine
% bnet         -> bayesian net before learning
% bnet2        -> bayesian net after learning
% ll           -> log-likelihood before learning
% LL2          -> log-likelihood trace
% onodes       -> obs nodes in bnet & bnet2
% max_em_iter  -> maximum number of iterations of the EM algorithm
% train_result -> prediction on the training set (test_result likewise for the test set)
%
% IMPORTANT: CHECK the loading paths marked %%%%%WARNING! below: the splash image
% located with which('HMEforMatlab.jpg') and the toy regression data loaded from
% [HOME '/examples/static/Misc/mixexp_data.txt'].
% ----------------------------------------------------------------------------------------------------
% -> pierpaolo_b@hotmail.com or -> pampo@interfree.it
% ----------------------------------------------------------------------------------------------------
33 | |
% This demo is kept for reference only: the error below aborts the script
% right away, before any of the code that follows is executed.
error('this no longer works with the latest version of BNT')

clear all;
clc;
% Introductory text, printed one entry per line.
intro_text = { ...
    '---------------------------------------------------'
    ' Hierarchical Mixtures of Experts models builder '
    '---------------------------------------------------'
    ' '
    ' Using this script you can build both an HME model'
    'as in [Wat94] and [Jor94] i.e. with ''softmax'' gating'
    'nodes and ''gaussian'' ( for regression ) or ''softmax'''
    '( for classification ) expert node, and its variants'
    'called ''gated nets'' where we use ''mlp'' models in'
    'place of a number of ''softmax'' ones [Mor98], [Wei95].'
    ' You can decide to train and test the model on your'
    'datasets or to evaluate its performance on a toy'
    'example.'
    ' '
    'Reference'
    '[Mor98] P. Moerland (1998):'
    ' Localized mixtures of experts. (http://www.idiap.ch/~perry/)'
    '[Jor94] M.I. Jordan, R.A. Jacobs (1994):'
    ' HME and the EM algorithm. (http://www.cs.berkeley.edu/~jordan/)'
    '[Wat94] S.R. Waterhouse, A.J. Robinson (1994):'
    ' Classification using HME. (http://www.oigeeza.com/steve/)'
    '[Wei95] A.S. Weigend, M. Mangeas (1995):'
    ' Nonlinear gated experts for time series.'
    ' '
    };
for k = 1:numel(intro_text)
    disp(intro_text{k});
end
clear intro_text k
62 | |
% Splash screen with the HME architecture diagram. It is switched off by the
% constant condition: nothing inside runs unless it is changed to 1.
if 0
    disp('(See the figure)')
    pause(5);
    %%%%%WARNING!%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % HMEforMatlab.jpg must be on the MATLAB path for which() to locate it.
    splash_img = imread(which('HMEforMatlab.jpg'), 'jpg');
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    figure('Units','pixels','MenuBar','none','NumberTitle','off', 'Name', 'HME model');
    image(splash_img);
    axis image;
    axis off;
    clear splash_img;
    set(gca,'Position',[0 0 1 1])
    disp('(Press any key to continue)')
    pause
end
79 | |
% --- Model type -----------------------------------------------------------
% type: 1 => regression, 2 => classification. Read by every later section.
% NOTE: the name shadows the built-in TYPE command for the rest of the script.
clc
disp('---------------------------------------------------');
disp(' Specify the Architecture ');
disp('---------------------------------------------------');
disp(' ');
disp('What kind of model do you need?')
disp(' ')
disp('1) Regression ')
disp('2) Classification')
disp(' ')
type=input('1 or 2?: ');
% Require exactly one value: with the element-wise '|' test a vector answer
% such as [1 3] slipped through, because IF fires only when all elements are true.
if isempty(type) || ~isscalar(type) || ~ismember(type,[1 2]), error('Invalid value'); end
% --- Gating levels, gating nodes and branching factor ---------------------
% Fills num_glevel, nodes_info (columns 2..num_glevel+1 = gating nodes) and,
% when num_glevel>0, num_exp and branch_fact.
% Answers are read into a temporary and validated before being stored:
% assigning an empty reply straight into nodes_info(r,i) raises a MATLAB
% null-assignment error, so the original isempty checks could never fire.
clc
disp('----------------------------------------------------');
disp(' Specify the Architecture ');
disp('----------------------------------------------------');
disp(' ')
disp('Now you have to set the number of experts and gating')
disp('levels in the net. This script builds only balanced')
disp('hierarchy with the same branching factor (>1)at each')
disp('(gating) level. So remember that: ')
disp(' ')
disp(' num_exp = branch_fact^num_glevel ')
disp(' ')
disp('with branch_fact >=2.')
disp('You can also set to zeros the number of gating level')
disp('in order to obtain a classical GLM model. ')
disp(' ')
disp('----------------------------------------------------');
disp(' ')
num_glevel=input('Insert the number of gating levels {0,...,20}: ');
if isempty(num_glevel) || ~isscalar(num_glevel) || ~ismember(num_glevel,0:20), error('Invalid value'); end
% Column 1 = covariate (root) node, columns 2..num_glevel+1 = gating levels,
% last column = expert/output node.
nodes_info=zeros(4,num_glevel+2);
if num_glevel>0, %------------------------------------------------------------------------------------
    for i=2:num_glevel+1,
        clc
        disp('----------------------------------------------------');
        disp(' Specify the Architecture ');
        disp('----------------------------------------------------');
        disp(' ')
        disp(['-> Gating network ', num2str(i-1), ' is a: '])
        disp(' ')
        disp(' 1) Softmax model');
        disp(' 2) Two layer perceptron model')
        disp(' ')
        answer=input('1 or 2?: ');
        if isempty(answer) || ~isscalar(answer) || ~ismember(answer,[1 2]), error('Invalid value'); end
        nodes_info(1,i)=answer;
        disp(' ')
        if nodes_info(1,i)==2,
            answer=input('Insert the number of units in the hidden layer: ');
            if isempty(answer) || ~isscalar(answer) || floor(answer)~=answer || answer<=0,
                error(['Invalid value: ', num2str(answer), ' is not a positive integer!']);
            end
            nodes_info(3,i)=answer;
            disp(' ')
        end
        answer=input('Insert the optimizer iteration number: ');
        if isempty(answer) || ~isscalar(answer) || floor(answer)~=answer || answer<=0,
            error(['Invalid value: ', num2str(answer), ' is not a positive integer!']);
        end
        nodes_info(4,i)=answer;
    end
    clear answer
    clc
    disp('---------------------------------------------------------');
    disp(' Specify the Architecture ');
    disp('---------------------------------------------------------');
    disp(' ')
    disp('Now you have to set the number of experts in the network');
    disp('The value will be adjusted in order to obtain a hierarchy');
    disp('as said above.')
    disp(' ');
    num_exp=input(['Insert the approximative number of experts (>=', num2str(2^num_glevel), '): ']);
    if isempty(num_exp) || ~isscalar(num_exp) || num_exp<=0 || num_exp<2^num_glevel,
        error('Invalid value');
    end
    % Find the smallest integer base (>=2) with base^num_glevel >= num_exp,
    % then prefer base-1 if its power is strictly closer to num_exp and base-1
    % is still >=2.
    app1=0; base=2;
    while app1<num_exp,
        app1=base^num_glevel;
        base=base+1;
    end
    app2=(base-2)^num_glevel;
    branch_fact=base-1;
    if app2>=(2^num_glevel) && (abs(app2-num_exp)<abs(app1-num_exp)),
        branch_fact=base-2;
    end
    clear app1 app2 base;
    disp(' ')
    disp(['The effective number of experts in the net is: ', num2str(branch_fact^num_glevel), '.'])
    disp(' ');
else
    clc
    disp('---------------------------------------------------------');
    disp(' Specify the Architecture (GLM model) ');
    disp('---------------------------------------------------------');
    disp(' ')
end % END of: if num_glevel>0-------------------------------------------------------------------------
174 | |
% --- Expert (output) node: last column of nodes_info -----------------------
% Classification: softmax or MLP expert, plus its optimizer settings.
% Regression: the expert keeps type 0 (gaussian); only its covariance
% structure is asked for and stored in nodes_info(4,end).
% Answers go through a temporary first: assigning an empty reply straight
% into nodes_info(r,end) raises a MATLAB null-assignment error, so the
% original isempty checks could never fire.
if type==2,
    disp('-> Expert node is a: ')
    disp(' ')
    disp(' 1) Softmax model');
    disp(' 2) Two layer perceptron model')
    disp(' ')
    answer=input('1 or 2?: ');
    if isempty(answer) || ~isscalar(answer) || ~ismember(answer,[1 2]),
        error('Invalid value');
    end
    nodes_info(1,end)=answer;
    disp(' ')
    if nodes_info(1,end)==2,
        answer=input('Insert the number of units in the hidden layer: ');
        if isempty(answer) || ~isscalar(answer) || floor(answer)~=answer || answer<=0,
            error(['Invalid value: ', num2str(answer), ' is not a positive integer!']);
        end
        nodes_info(3,end)=answer;
        disp(' ')
    end
    answer=input('Insert the optimizer iteration number: ');
    if isempty(answer) || ~isscalar(answer) || floor(answer)~=answer || answer<=0,
        error(['Invalid value: ', num2str(answer), ' is not a positive integer!']);
    end
    nodes_info(4,end)=answer;
elseif type==1,
    disp('What kind of covariance matrix structure do you want?')
    disp(' ')
    disp(' 1) Full');
    disp(' 2) Diagonal')
    disp(' 3) Full & Tied');
    disp(' 4) Diagonal & Tied')

    disp(' ')
    answer=input('1, 2, 3 or 4?: ');
    if isempty(answer) || ~isscalar(answer) || ~ismember(answer,[1 2 3 4]),
        error('Invalid value');
    end
    nodes_info(4,end)=answer;
end
clear answer
% --- Data source: 1 => user-supplied files, 2 => built-in toy example ------
clc
disp('----------------------------------------------------');
disp(' Specify the Input ');
disp('----------------------------------------------------');
disp(' ')
disp('Do you want to...')
disp(' ')
disp('1) ...use your own dataset?')
disp('2) ...apply the model on a toy example?')
disp(' ')
dataset=input('1 or 2?: ');
% Scalar check: the element-wise '|' test let vector answers through.
if isempty(dataset) || ~isscalar(dataset) || ~ismember(dataset,[1 2]), error('Invalid value'); end
% --- Input data -------------------------------------------------------------
% Sets cov_dim, res_dim, train_data, train_d, train_t and ntrain, and - when a
% test set exists - test_data, test_d, ntest (test_t only if the test data
% carries targets). test_path stays empty when the user gives no test file;
% later sections use isempty(test_path) to decide whether test results exist.
if dataset==1,
    if type==1,
        clc
        disp('-------------------------------------------------------');
        disp(' Specify the Input - Regression problem ');
        disp('-------------------------------------------------------');
        disp(' ')
        disp('Be sure that each row of your data matrix is an example');
        disp('with the covariate values that precede the respond ones')
        disp(' ')
        disp('-------------------------------------------------------');
        disp(' ')
        cov_dim=input('Insert the covariate space dimension: ');
        if isempty(cov_dim) || ~isscalar(cov_dim) || floor(cov_dim)~=cov_dim || cov_dim<=0,
            error(['Invalid value: ', num2str(cov_dim), ' is not a positive integer!']);
        end
        disp(' ')
        res_dim=input('Insert the dimension of the respond variable: ');
        if isempty(res_dim) || ~isscalar(res_dim) || floor(res_dim)~=res_dim || res_dim<=0,
            error(['Invalid value: ', num2str(res_dim), ' is not a positive integer!']);
        end
        disp(' ');
    elseif type==2
        clc
        disp('-------------------------------------------------------');
        disp(' Specify the Input - Classification problem ');
        disp('-------------------------------------------------------');
        disp(' ')
        disp('Be sure that each row of your data matrix is an example');
        disp('with the covariate values that precede the class labels');
        disp('(integer value >=1). ');
        disp(' ')
        disp('-------------------------------------------------------');
        disp(' ')
        cov_dim=input('Insert the covariate space dimension: ');
        if isempty(cov_dim) || ~isscalar(cov_dim) || floor(cov_dim)~=cov_dim || cov_dim<=0,
            error(['Invalid value: ', num2str(cov_dim), ' is not a positive integer!']);
        end
        disp(' ')
        res_dim=input('Insert the number of classes: ');
        if isempty(res_dim) || ~isscalar(res_dim) || floor(res_dim)~=res_dim || res_dim<=0,
            error(['Invalid value: ', num2str(res_dim), ' is not a positive integer!']);
        end
        disp(' ')
    end
    % ------------------------------------------------------------------------------------------------
    % Loading training data --------------------------------------------------------------------------
    % ------------------------------------------------------------------------------------------------
    train_path=input('Insert the complete (with extension) path of the training data file:\n >> ','s');
    if isempty(train_path), error('You must specify a data set for training!'); end
    % Dispatch on the real file extension. The old findstr('.mat',path) test
    % matched the pattern anywhere in the path, and findstr searches the
    % shorter argument inside the longer, so a path such as 'm' was taken
    % for a .mat file.
    [data_dir, data_name, data_ext]=fileparts(train_path);
    if strcmpi(data_ext,'.mat'),
        % Use the first variable stored in the MAT-file.
        ap=load(train_path); app=fieldnames(ap);
        if isempty(app), error('Invalid data file: the .mat file contains no variables'); end
        train_data=ap.(app{1});
        clear ap app;
    elseif strcmpi(data_ext,'.txt'),
        train_data=load(train_path, '-ascii');
    else
        error('Invalid data format: not a .mat or a .txt file')
    end
    if (type==1) && (size(train_data,2)~=cov_dim+res_dim),
        error(['Invalid data matrix size: ', num2str(size(train_data,2)), ' columns rather than ',...
            num2str(cov_dim+res_dim),'!']);
    elseif (type==2) && (size(train_data,2)~=cov_dim+1),
        error(['Invalid data matrix size: ', num2str(size(train_data,2)), ' columns rather than ',...
            num2str(cov_dim+1),'!']);
    elseif (type==2) && any(~ismember(train_data(:,end),1:res_dim)),
        % Every label must be one of 1..res_dim (this also rejects NaN labels,
        % which the former intersect-based test silently dropped).
        error('Invalid class label');
    end
    ntrain=size(train_data,1);
    train_d=train_data(:,1:cov_dim);
    if type==2,
        % 1-of-res_dim coding of the class label in the last column.
        train_t=zeros(ntrain, res_dim);
        for m=1:res_dim,
            train_t(train_data(:,end)==m,m)=1;
        end
    else
        train_t=train_data(:,cov_dim+1:end);
    end
    disp(' ')
    % ------------------------------------------------------------------------------------------------
    % Loading test data ------------------------------------------------------------------------------
    % ------------------------------------------------------------------------------------------------
    disp('(If you don''t want to specify a test-set press ''return'' only)');
    test_path=input('Insert the complete (with extension) path of the test data file:\n >> ','s');
    if ~isempty(test_path),
        [data_dir, data_name, data_ext]=fileparts(test_path);
        if strcmpi(data_ext,'.mat'),
            ap=load(test_path); app=fieldnames(ap);
            if isempty(app), error('Invalid data file: the .mat file contains no variables'); end
            test_data=ap.(app{1});
            clear ap app;
        elseif strcmpi(data_ext,'.txt'),
            test_data=load(test_path, '-ascii');
        else
            error('Invalid data format: not a .mat or a .txt file')
        end
        % The test file may hold covariates only, or covariates plus targets.
        if (type==1) && (size(test_data,2)~=cov_dim) && (size(test_data,2)~=cov_dim+res_dim),
            error(['Invalid data matrix size: ', num2str(size(test_data,2)), ' columns rather than ',...
                num2str(cov_dim+res_dim), ' or ', num2str(cov_dim), '!']);
        elseif (type==2) && (size(test_data,2)~=cov_dim) && (size(test_data,2)~=cov_dim+1),
            error(['Invalid data matrix size: ', num2str(size(test_data,2)), ' columns rather than ',...
                num2str(cov_dim+1), ' or ', num2str(cov_dim), '!']);
        elseif (type==2) && (size(test_data,2)==cov_dim+1) && any(~ismember(test_data(:,end),1:res_dim)),
            error('Invalid class label');
        end
        ntest=size(test_data,1);
        test_d=test_data(:,1:cov_dim);
        if (type==2) && (size(test_data,2)>cov_dim),
            test_t=zeros(ntest, res_dim);
            for m=1:res_dim,
                test_t(test_data(:,end)==m,m)=1;
            end
        elseif (type==1) && (size(test_data,2)>cov_dim),
            test_t=test_data(:,cov_dim+1:end);
        end
        disp(' ');
    end
    clear data_dir data_name data_ext
else
    clc
    disp('----------------------------------------------------');
    disp(' Specify the Input ');
    disp('----------------------------------------------------');
    disp(' ')
    ntrain = input('Insert the number of examples in training (<500): ');
    if isempty(ntrain) || ~isscalar(ntrain) || floor(ntrain)~=ntrain || ntrain<=0 || ntrain>500,
        error(['Invalid value: ', num2str(ntrain), ' is not a positive integer <500!']);
    end
    disp(' ')
    test_path='toy';
    ntest = input('Insert the number of examples in test (<500): ');
    if isempty(ntest) || ~isscalar(ntest) || floor(ntest)~=ntest || ntest<=0 || ntest>500,
        error(['Invalid value: ', num2str(ntest), ' is not a positive integer <500!']);
    end

    if type==2,
        cov_dim=2;
        res_dim=3;
        seed = 42;
        [train_d, ntrain1, ntrain2, train_t]=gen_data(ntrain, seed);
        % Append the class index (position of the 1 in each target row).
        for m=1:ntrain
            train_data(m,:)=[train_d(m,:) find(train_t(m,:)==1)];
        end
        [test_d, ntest1, ntest2, test_t]=gen_data(ntest);
        for m=1:ntest
            test_data(m,:)=[test_d(m,:) find(test_t(m,:)==1)];
        end
    else
        cov_dim=1;
        res_dim=1;
        global HOME
        %%%%%WARNING!%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        % The toy regression data is read relative to the global HOME, which
        % must be set before running this script.
        if isempty(HOME), error('The global variable HOME is not set'); end
        mixexp_data=load([HOME '/examples/static/Misc/mixexp_data.txt'], '-ascii');
        %%%%%WARNING!%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        if ntrain+ntest>size(mixexp_data,1),
            error(['The toy data set has only ', num2str(size(mixexp_data,1)),...
                ' examples: training plus test examples must not exceed this number!']);
        end
        train_data = mixexp_data(1:ntrain, :);
        train_d=train_data(:,1:cov_dim); train_t=train_data(:,cov_dim+1:end);
        test_data = mixexp_data(ntrain+1:ntrain+ntest, :);
        test_d=test_data(:,1:cov_dim);
        if size(test_data,2)>cov_dim,
            test_t=test_data(:,cov_dim+1:end);
        end
    end
end
% --- Node sizes --------------------------------------------------------------
% Row 2 of nodes_info: root node = cov_dim, each gating level = branch_fact,
% output node = res_dim. The three assignments touch disjoint columns.
nodes_info(2,1)=cov_dim;
if num_glevel>0,
    nodes_info(2,2:num_glevel+1)=branch_fact;
end
nodes_info(2,end)=res_dim;
%-----------------------------------------------------------
% Evidence for the learning engine: one column per training example. Only
% the first (covariate) and last (response) rows are filled; the gating
% rows stay empty cells.
%-----------------------------------------------------------
cases = cell(size(nodes_info,2), ntrain);
cases(1,:)   = num2cell(train_data(1:ntrain,1:cov_dim)', 1);
cases(end,:) = num2cell(train_data(1:ntrain,cov_dim+1:end)', 1);
%-----------------------------------------------------------------------------------------------------
% Build the HME network from nodes_info, attach a junction-tree engine over
% the observed nodes, and sum the log-likelihood returned by enter_evidence
% for every training case before any learning takes place.
[bnet, onodes]=hme_topobuilder(nodes_info);
engine=jtree_inf_engine(bnet, onodes);
clc
disp('---------------------------------------------------------------------');
disp(' L E A R N I N G ');
disp('---------------------------------------------------------------------');
disp(' ')
ll=0;
for l=1:ntrain
    disp(['example number: ', int2str(l), '---------------------------------------------']);
    [engine, loglik]=enter_evidence(engine, cases(:,l));
    ll=ll+loglik;
end
disp(' ')
disp(['Log-likelihood before learning: ', num2str(ll)]);
disp(' ')
disp('(Press any key to continue)');
pause
%-----------------------------------------------------------
% EM training: ask for the iteration cap, run learn_params_em and report
% the log-likelihood before learning and at the last EM iteration.
clc
disp('---------------------------------------------------------------------');
disp(' L E A R N I N G ');
disp('---------------------------------------------------------------------');
disp(' ')
max_em_iter=input('Insert the maximum number of the EM algorithm iterations: ');
if isempty(max_em_iter) || ~isscalar(max_em_iter) || floor(max_em_iter)~=max_em_iter || max_em_iter<=1,
    % Report the value that was actually rejected (the old message printed ntest).
    error(['Invalid value: ', num2str(max_em_iter), ' is not a positive integer >1!']);
end
disp(' ')
disp(['Log-likelihood before learning: ', num2str(ll)]);
disp(' ')

[bnet2, LL2] = learn_params_em(engine, cases, max_em_iter);
disp(' ')
fprintf('HME: loglik before learning %f, after %d iters %f\n', ll, length(LL2), LL2(end));
disp(' ')
disp('(Press any key to continue)');
pause
%-----------------------------------------------------------------------------------
% Classification with 2-D input: plot data & decision boundaries.
% Regression with 1-D input: plot data & prediction.
% Any other configuration produces no plot. The test set is passed to the
% plotting helper only when one was given (non-empty test_path).
%-----------------------------------------------------------------------------------
if (type==2)&(nodes_info(2,1)==2),
    if isempty(test_path),
        fh1=hme_class_plot(bnet2, nodes_info, train_data);
    else
        fh1=hme_class_plot(bnet2, nodes_info, train_data, test_data);
    end
    disp(' ');
    disp('(See the figure)');
elseif (type==1)&(nodes_info(2,1)==1),
    if isempty(test_path),
        fh1=hme_reg_plot(bnet2, nodes_info, train_data);
    else
        fh1=hme_reg_plot(bnet2, nodes_info, train_data, test_data);
    end
    disp(' ')
    disp('(See the figure)');
end
%-----------------------------------------------------------------------------------
% Classification problem: plot confusion matrix (training set, plus a second
% panel for the test set when the test data carries labels).
%-----------------------------------------------------------------------------------
if (type==2)
    ztrain=fhme(bnet2, nodes_info, train_d, size(train_d,1));
    [Htrain, trainRate]=confmat(ztrain, train_t);
    fh2=figure('Name','Confusion matrix', 'MenuBar', 'none', 'NumberTitle', 'off');
    % Evaluated inside IF so test_data is only touched when test_path is set.
    if (~isempty(test_path))&(size(test_data,2)>cov_dim),
        ztest=fhme(bnet2, nodes_info, test_d, size(test_d,1));
        [Htest, testRate]=confmat(ztest, test_t);
        subplot(1,2,1);
    end
    % Tick positions 0.5, 1.5, ..., one per class.
    tick=(1:nodes_info(2,end))-0.5;
    plotmat(Htrain,'b','k',12)
    set(gca,'XTick',tick,'YTick',tick)
    grid('off')
    xlabel('Prediction')
    ylabel('True')
    title(['Confusion Matrix: training set (' num2str(trainRate(1)) '%)'])
    if (~isempty(test_path))&(size(test_data,2)>cov_dim),
        subplot(1,2,2)
        plotmat(Htest,'b','k',12)
        set(gca,'XTick',tick,'YTick',tick)
        grid('off')
        xlabel('Prediction')
        ylabel('True')
        title(['Confusion Matrix: test set (' num2str(testRate(1)) '%)'])
    end
    disp(' ')
    disp('(Press any key to continue)');
    pause
end
%-----------------------------------------------------------------------------------
% Regression & Classification problem: calculate the predictions & plot the LL trace
%-----------------------------------------------------------------------------------
% train_result / test_result hold fhme's output for the training / test
% covariates; test_result exists only when a test set was given.
train_result=fhme(bnet2,nodes_info,train_d,size(train_d,1));
if ~isempty(test_path),
    test_result=fhme(bnet2,nodes_info,test_d,size(test_d,1));
end
% Terminating semicolon added: without it MATLAB echoed the figure handle.
fh3=figure('Name','Log-likelihood trace', 'MenuBar', 'none', 'NumberTitle', 'off');
plot(LL2,'-ro',...
    'MarkerEdgeColor','k',...
    'MarkerFaceColor',[1 1 0],...
    'MarkerSize',4)
title('Log-likelihood trace')
%-----------------------------------------------------------------------------------
% Regression & Classification problem: save the predictions
%-----------------------------------------------------------------------------------
% Each Y/N answer is trimmed and upper-cased; an empty reply defaults to 'Y'.
% The old test `if ~findstr(answer, ['Y','N'])` evaluated ~[] (an empty,
% hence false, condition) for unknown replies, so any invalid answer was
% silently treated as 'N' instead of raising 'Invalid input'.
clc
disp('------------------------------------------------------------------');
disp(' Save the results ');
disp('------------------------------------------------------------------');
disp(' ')
%-----------------------------------------------------------------------------------
save_quest_m=upper(strtrim(input('Do you want to save the HME model (Y/N)? [Y default]: ', 's')));
if isempty(save_quest_m),
    save_quest_m='Y';
end
if ~any(strcmp(save_quest_m, {'Y', 'N'})), error('Invalid input'); end
if strcmp(save_quest_m, 'Y'),
    disp(' ');
    m_save=input('Insert the complete path for save the HME model (.mat):\n >> ', 's');
    if isempty(m_save), error('You must specify a path!'); end
    save(m_save, 'bnet2');
end
%-----------------------------------------------------------------------------------
disp(' ')
save_quest=upper(strtrim(input('Do you want to save the HME predictions (Y/N)? [Y default]: ', 's')));
disp(' ')
if isempty(save_quest),
    save_quest='Y';
end
if ~any(strcmp(save_quest, {'Y', 'N'})), error('Invalid input'); end
if strcmp(save_quest, 'Y'),
    tr_save=input('Insert the complete path for save the training data prediction (.mat):\n >> ', 's');
    if isempty(tr_save), error('You must specify a path!'); end
    save(tr_save, 'train_result');
    if ~isempty(test_path),
        disp(' ')
        te_save=input('Insert the complete path for save the test data prediction (.mat):\n >> ', 's');
        if isempty(te_save), error('You must specify a path!'); end
        save(te_save, 'test_result');
    end
end
% Farewell banner, shown for two seconds before the command window is cleared.
clc
bye_banner = {'----------------------------------------------------', ...
              ' B Y E ! ', ...
              '----------------------------------------------------'};
for k = 1:numel(bye_banner)
    disp(bye_banner{k});
end
clear bye_banner k
pause(2)
%clear
clc