wolffd@0
|
1 % dataset -> (1=>user data) or (2=>toy example)
|
wolffd@0
|
2 % type -> (1=> Regression model) or (2=>Classification model)
|
wolffd@0
|
3 % num_glevel -> number of hidden nodes in the net (gating levels)
|
wolffd@0
|
4 % num_exp -> number of experts in the net
|
wolffd@0
|
5 % branch_fact -> dimension of the hidden nodes in the net
|
wolffd@0
|
6 % cov_dim -> root node dimension
|
wolffd@0
|
7 % res_dim -> output node dimension
|
wolffd@0
|
8 % nodes_info -> 4 x (num_glevel+2) matrix that contains all the info about the nodes:
|
wolffd@0
|
9 % nodes_info(1,:) = nodes type: (0=>gaussian)or(1=>softmax)or(2=>mlp)
|
wolffd@0
|
10 % nodes_info(2,:) = nodes size: [cov_dim num_glevel x branch_fact res_dim]
|
wolffd@0
|
11 % nodes_info(3,:) = hidden units number (for mlp nodes)
|
wolffd@0
|
12 % |- optimizer iteration number (for softmax & mlp CPD)
|
wolffd@0
|
13 % nodes_info(4,:) =|- covariance type (for gaussian CPD)->
|
wolffd@0
|
14 % | (1=>Full)or(2=>Diagonal)or(3=>Full&Tied)or(4=>Diagonal&Tied)
|
wolffd@0
|
15 % fh1 -> Figure: data & decision boundaries; fh2 -> confusion matrix; fh3 -> LL trace
|
wolffd@0
|
16 % test_data -> test data matrix
|
wolffd@0
|
17 % train_data -> training data matrix
|
wolffd@0
|
18 % ntrain -> size(train_data,1)
|
wolffd@0
|
19 % ntest -> size(test_data,1)
|
wolffd@0
|
20 % cases -> (cell array) training data formatted for the learning engine
|
wolffd@0
|
21 % bnet -> bayesian net before learning
|
wolffd@0
|
22 % bnet2 -> bayesian net after learning
|
wolffd@0
|
23 % ll -> log-likelihood before learning
|
wolffd@0
|
24 % LL2 -> log-likelihood trace
|
wolffd@0
|
25 % onodes -> obs nodes in bnet & bnet2
|
wolffd@0
|
26 % max_em_iter -> maximum number of iterations of the EM algorithm
|
wolffd@0
|
27 % train_result -> prediction on the training set (as test_result)
|
wolffd@0
|
28 %
|
wolffd@0
|
29 % IMPORTANT: CHECK the loading path (lines 67 & 374)
|
wolffd@0
|
30 % ----------------------------------------------------------------------------------------------------
|
wolffd@0
|
31 % -> pierpaolo_b@hotmail.com or -> pampo@interfree.it
|
wolffd@0
|
32 % ----------------------------------------------------------------------------------------------------
|
wolffd@0
|
33
|
wolffd@0
|
34 error('this no longer works with the latest version of BNT')
|
wolffd@0
|
35
|
wolffd@0
|
36 clear all;
|
wolffd@0
|
37 clc;
|
wolffd@0
|
38 disp('---------------------------------------------------');
|
wolffd@0
|
39 disp(' Hierarchical Mixtures of Experts models builder ');
|
wolffd@0
|
40 disp('---------------------------------------------------');
|
wolffd@0
|
41 disp(' ')
|
wolffd@0
|
42 disp(' Using this script you can build both an HME model')
|
wolffd@0
|
43 disp('as in [Wat94] and [Jor94] i.e. with ''softmax'' gating')
|
wolffd@0
|
44 disp('nodes and ''gaussian'' ( for regression ) or ''softmax''')
|
wolffd@0
|
45 disp('( for classification ) expert node, and its variants')
|
wolffd@0
|
46 disp('called ''gated nets'' where we use ''mlp'' models in')
|
wolffd@0
|
47 disp('place of a number of ''softmax'' ones [Mor98], [Wei95].')
|
wolffd@0
|
48 disp(' You can decide to train and test the model on your')
|
wolffd@0
|
49 disp('datasets or to evaluate its performance on a toy')
|
wolffd@0
|
50 disp('example.')
|
wolffd@0
|
51 disp(' ')
|
wolffd@0
|
52 disp('Reference')
|
wolffd@0
|
53 disp('[Mor98] P. Moerland (1998):')
|
wolffd@0
|
54 disp(' Localized mixtures of experts. (http://www.idiap.ch/~perry/)')
|
wolffd@0
|
55 disp('[Jor94] M.I. Jordan, R.A. Jacobs (1994):')
|
wolffd@0
|
56 disp(' HME and the EM algorithm. (http://www.cs.berkeley.edu/~jordan/)')
|
wolffd@0
|
57 disp('[Wat94] S.R. Waterhouse, A.J. Robinson (1994):')
|
wolffd@0
|
58 disp(' Classification using HME. (http://www.oigeeza.com/steve/)')
|
wolffd@0
|
59 disp('[Wei95] A.S. Weigend, M. Mangeas (1995):')
|
wolffd@0
|
60 disp(' Nonlinear gated experts for time series.')
|
wolffd@0
|
61 disp(' ')
|
wolffd@0
|
62
|
wolffd@0
|
63 if 0
|
wolffd@0
|
64 disp('(See the figure)')
|
wolffd@0
|
65 pause(5);
|
wolffd@0
|
66 %%%%%WARNING!%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
wolffd@0
|
67 im_path=which('HMEforMatlab.jpg');
|
wolffd@0
|
68 fig=imread(im_path, 'jpg');
|
wolffd@0
|
69 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
wolffd@0
|
70 figure('Units','pixels','MenuBar','none','NumberTitle','off', 'Name', 'HME model');
|
wolffd@0
|
71 image(fig);
|
wolffd@0
|
72 axis image;
|
wolffd@0
|
73 axis off;
|
wolffd@0
|
74 clear fig;
|
wolffd@0
|
75 set(gca,'Position',[0 0 1 1])
|
wolffd@0
|
76 disp('(Press any key to continue)')
|
wolffd@0
|
77 pause
|
wolffd@0
|
78 end
|
wolffd@0
|
79
|
wolffd@0
|
80 clc
|
wolffd@0
|
81 disp('---------------------------------------------------');
|
wolffd@0
|
82 disp(' Specify the Architecture ');
|
wolffd@0
|
83 disp('---------------------------------------------------');
|
wolffd@0
|
84 disp(' ');
|
wolffd@0
|
85 disp('What kind of model do you need?')
|
wolffd@0
|
86 disp(' ')
|
wolffd@0
|
87 disp('1) Regression ')
|
wolffd@0
|
88 disp('2) Classification')
|
wolffd@0
|
89 disp(' ')
|
wolffd@0
|
90 type=input('1 or 2?: ');
|
wolffd@0
|
91 if (isempty(type)|(~ismember(type,[1 2]))), error('Invalid value'); end
|
wolffd@0
|
92 clc
|
wolffd@0
|
93 disp('----------------------------------------------------');
|
wolffd@0
|
94 disp(' Specify the Architecture ');
|
wolffd@0
|
95 disp('----------------------------------------------------');
|
wolffd@0
|
96 disp(' ')
|
wolffd@0
|
97 disp('Now you have to set the number of experts and gating')
|
wolffd@0
|
98 disp('levels in the net. This script builds only balanced')
|
wolffd@0
|
99 disp('hierarchy with the same branching factor (>1)at each')
|
wolffd@0
|
100 disp('(gating) level. So remember that: ')
|
wolffd@0
|
101 disp(' ')
|
wolffd@0
|
102 disp(' num_exp = branch_fact^num_glevel ')
|
wolffd@0
|
103 disp(' ')
|
wolffd@0
|
104 disp('with branch_fact >=2.')
|
wolffd@0
|
105 disp('You can also set to zeros the number of gating level')
|
wolffd@0
|
106 disp('in order to obtain a classical GLM model. ')
|
wolffd@0
|
107 disp(' ')
|
wolffd@0
|
108 disp('----------------------------------------------------');
|
wolffd@0
|
109 disp(' ')
|
wolffd@0
|
110 num_glevel=input('Insert the number of gating levels {0,...,20}: ');
|
wolffd@0
|
111 if (isempty(num_glevel)|(~ismember(num_glevel,[0:20]))), error('Invalid value'); end
|
wolffd@0
|
112 nodes_info=zeros(4,num_glevel+2);
|
wolffd@0
|
113 if num_glevel>0, %------------------------------------------------------------------------------------
|
wolffd@0
|
114 for i=2:num_glevel+1,
|
wolffd@0
|
115 clc
|
wolffd@0
|
116 disp('----------------------------------------------------');
|
wolffd@0
|
117 disp(' Specify the Architecture ');
|
wolffd@0
|
118 disp('----------------------------------------------------');
|
wolffd@0
|
119 disp(' ')
|
wolffd@0
|
120 disp(['-> Gating network ', num2str(i-1), ' is a: '])
|
wolffd@0
|
121 disp(' ')
|
wolffd@0
|
122 disp(' 1) Softmax model');
|
wolffd@0
|
123 disp(' 2) Two layer perceptron model')
|
wolffd@0
|
124 disp(' ')
|
wolffd@0
|
125 nodes_info(1,i)=input('1 or 2?: ');
|
wolffd@0
|
126 if (isempty(nodes_info(1,i))|(~ismember(nodes_info(1,i),[1 2]))), error('Invalid value'); end
|
wolffd@0
|
127 disp(' ')
|
wolffd@0
|
128 if nodes_info(1,i)==2,
|
wolffd@0
|
129 nodes_info(3,i)=input('Insert the number of units in the hidden layer: ');
|
wolffd@0
|
130 if (isempty(nodes_info(3,i))|(floor(nodes_info(3,i))~=nodes_info(3,i))|(nodes_info(3,i)<=0)),
|
wolffd@0
|
131 error(['Invalid value: ', num2str(nodes_info(3,i)), ' is not a positive integer!']);
|
wolffd@0
|
132 end
|
wolffd@0
|
133 disp(' ')
|
wolffd@0
|
134 end
|
wolffd@0
|
135 nodes_info(4,i)=input('Insert the optimizer iteration number: ');
|
wolffd@0
|
136 if (isempty(nodes_info(4,i))|(floor(nodes_info(4,i))~=nodes_info(4,i))|(nodes_info(4,i)<=0)),
|
wolffd@0
|
137 error(['Invalid value: ', num2str(nodes_info(4,i)), ' is not a positive integer!']);
|
wolffd@0
|
138 end
|
wolffd@0
|
139 end
|
wolffd@0
|
140 clc
|
wolffd@0
|
141 disp('---------------------------------------------------------');
|
wolffd@0
|
142 disp(' Specify the Architecture ');
|
wolffd@0
|
143 disp('---------------------------------------------------------');
|
wolffd@0
|
144 disp(' ')
|
wolffd@0
|
145 disp('Now you have to set the number of experts in the network');
|
wolffd@0
|
146 disp('The value will be adjusted in order to obtain a hierarchy');
|
wolffd@0
|
147 disp('as said above.')
|
wolffd@0
|
148 disp(' ');
|
wolffd@0
|
149 num_exp=input(['Insert the approximative number of experts (>=', num2str(2^num_glevel), '): ']);
|
wolffd@0
|
150 if (isempty(num_exp)|(num_exp<=0)|(num_exp<2^num_glevel)),
|
wolffd@0
|
151 error('Invalid value');
|
wolffd@0
|
152 end
|
wolffd@0
|
153 app1=0; base=2;
|
wolffd@0
|
154 while app1<num_exp,
|
wolffd@0
|
155 app1=base^num_glevel;
|
wolffd@0
|
156 base=base+1;
|
wolffd@0
|
157 end
|
wolffd@0
|
158 app2=(base-2)^num_glevel;
|
wolffd@0
|
159 branch_fact=base-1;
|
wolffd@0
|
160 if app2>=(2^num_glevel)&(abs(app2-num_exp)<abs(app1-num_exp)),
|
wolffd@0
|
161 branch_fact=base-2;
|
wolffd@0
|
162 end
|
wolffd@0
|
163 clear app1 app2 base;
|
wolffd@0
|
164 disp(' ')
|
wolffd@0
|
165 disp(['The effective number of experts in the net is: ', num2str(branch_fact^num_glevel), '.'])
|
wolffd@0
|
166 disp(' ');
|
wolffd@0
|
167 else
|
wolffd@0
|
168 clc
|
wolffd@0
|
169 disp('---------------------------------------------------------');
|
wolffd@0
|
170 disp(' Specify the Architecture (GLM model) ');
|
wolffd@0
|
171 disp('---------------------------------------------------------');
|
wolffd@0
|
172 disp(' ')
|
wolffd@0
|
173 end % END of: if num_glevel>0-------------------------------------------------------------------------
|
wolffd@0
|
174
|
wolffd@0
|
175 if type==2,
|
wolffd@0
|
176 disp(['-> Expert node is a: '])
|
wolffd@0
|
177 disp(' ')
|
wolffd@0
|
178 disp(' 1) Softmax model');
|
wolffd@0
|
179 disp(' 2) Two layer perceptron model')
|
wolffd@0
|
180 disp(' ')
|
wolffd@0
|
181 nodes_info(1,end)=input('1 or 2?: ');
|
wolffd@0
|
182 if (isempty(nodes_info(1,end))|(~ismember(nodes_info(1,end),[1 2]))),
|
wolffd@0
|
183 error('Invalid value');
|
wolffd@0
|
184 end
|
wolffd@0
|
185 disp(' ')
|
wolffd@0
|
186 if nodes_info(1,end)==2,
|
wolffd@0
|
187 nodes_info(3,end)=input('Insert the number of units in the hidden layer: ');
|
wolffd@0
|
188 if (isempty(nodes_info(3,end))|(floor(nodes_info(3,end))~=nodes_info(3,end))|(nodes_info(3,end)<=0)),
|
wolffd@0
|
189 error(['Invalid value: ', num2str(nodes_info(3,end)), ' is not a positive integer!']);
|
wolffd@0
|
190 end
|
wolffd@0
|
191 disp(' ')
|
wolffd@0
|
192 end
|
wolffd@0
|
193 nodes_info(4,end)=input('Insert the optimizer iteration number: ');
|
wolffd@0
|
194 if (isempty(nodes_info(4,end))|(floor(nodes_info(4,end))~=nodes_info(4,end))|(nodes_info(4,end)<=0)),
|
wolffd@0
|
195 error(['Invalid value: ', num2str(nodes_info(4,end)), ' is not a positive integer!']);
|
wolffd@0
|
196 end
|
wolffd@0
|
197 elseif type==1,
|
wolffd@0
|
198 disp('What kind of covariance matrix structure do you want?')
|
wolffd@0
|
199 disp(' ')
|
wolffd@0
|
200 disp(' 1) Full');
|
wolffd@0
|
201 disp(' 2) Diagonal')
|
wolffd@0
|
202 disp(' 3) Full & Tied');
|
wolffd@0
|
203 disp(' 4) Diagonal & Tied')
|
wolffd@0
|
204
|
wolffd@0
|
205 disp(' ')
|
wolffd@0
|
206 nodes_info(4,end)=input('1, 2, 3 or 4?: ');
|
wolffd@0
|
207 if (isempty(nodes_info(4,end))|(~ismember(nodes_info(4,end),[1 2 3 4]))),
|
wolffd@0
|
208 error('Invalid value');
|
wolffd@0
|
209 end
|
wolffd@0
|
210 end
|
wolffd@0
|
211 clc
|
wolffd@0
|
212 disp('----------------------------------------------------');
|
wolffd@0
|
213 disp(' Specify the Input ');
|
wolffd@0
|
214 disp('----------------------------------------------------');
|
wolffd@0
|
215 disp(' ')
|
wolffd@0
|
216 disp('Do you want to...')
|
wolffd@0
|
217 disp(' ')
|
wolffd@0
|
218 disp('1) ...use your own dataset?')
|
wolffd@0
|
219 disp('2) ...apply the model on a toy example?')
|
wolffd@0
|
220 disp(' ')
|
wolffd@0
|
221 dataset=input('1 or 2?: ');
|
wolffd@0
|
222 if (isempty(dataset)|(~ismember(dataset,[1 2]))), error('Invalid value'); end
|
wolffd@0
|
223 if dataset==1,
|
wolffd@0
|
224 if type==1,
|
wolffd@0
|
225 clc
|
wolffd@0
|
226 disp('-------------------------------------------------------');
|
wolffd@0
|
227 disp(' Specify the Input - Regression problem ');
|
wolffd@0
|
228 disp('-------------------------------------------------------');
|
wolffd@0
|
229 disp(' ')
|
wolffd@0
|
230 disp('Be sure that each row of your data matrix is an example');
|
wolffd@0
|
231 disp('with the covariate values that precede the respond ones')
|
wolffd@0
|
232 disp(' ')
|
wolffd@0
|
233 disp('-------------------------------------------------------');
|
wolffd@0
|
234 disp(' ')
|
wolffd@0
|
235 cov_dim=input('Insert the covariate space dimension: ');
|
wolffd@0
|
236 if (isempty(cov_dim)|(floor(cov_dim)~=cov_dim)|(cov_dim<=0)),
|
wolffd@0
|
237 error(['Invalid value: ', num2str(cov_dim), ' is not a positive integer!']);
|
wolffd@0
|
238 end
|
wolffd@0
|
239 disp(' ')
|
wolffd@0
|
240 res_dim=input('Insert the dimension of the respond variable: ');
|
wolffd@0
|
241 if (isempty(res_dim)|(floor(res_dim)~=res_dim)|(res_dim<=0)),
|
wolffd@0
|
242 error(['Invalid value: ', num2str(res_dim), ' is not a positive integer!']);
|
wolffd@0
|
243 end
|
wolffd@0
|
244 disp(' ');
|
wolffd@0
|
245 elseif type==2
|
wolffd@0
|
246 clc
|
wolffd@0
|
247 disp('-------------------------------------------------------');
|
wolffd@0
|
248 disp(' Specify the Input - Classification problem ');
|
wolffd@0
|
249 disp('-------------------------------------------------------');
|
wolffd@0
|
250 disp(' ')
|
wolffd@0
|
251 disp('Be sure that each row of your data matrix is an example');
|
wolffd@0
|
252 disp('with the covariate values that precede the class labels');
|
wolffd@0
|
253 disp('(integer value >=1). ');
|
wolffd@0
|
254 disp(' ')
|
wolffd@0
|
255 disp('-------------------------------------------------------');
|
wolffd@0
|
256 disp(' ')
|
wolffd@0
|
257 cov_dim=input('Insert the covariate space dimension: ');
|
wolffd@0
|
258 if (isempty(cov_dim)|(floor(cov_dim)~=cov_dim)|(cov_dim<=0)),
|
wolffd@0
|
259 error(['Invalid value: ', num2str(cov_dim), ' is not a positive integer!']);
|
wolffd@0
|
260 end
|
wolffd@0
|
261 disp(' ')
|
wolffd@0
|
262 res_dim=input('Insert the number of classes: ');
|
wolffd@0
|
263 if (isempty(res_dim)|(floor(res_dim)~=res_dim)|(res_dim<=0)),
|
wolffd@0
|
264 error(['Invalid value: ', num2str(res_dim), ' is not a positive integer!']);
|
wolffd@0
|
265 end
|
wolffd@0
|
266 disp(' ')
|
wolffd@0
|
267 end
|
wolffd@0
|
268 % ------------------------------------------------------------------------------------------------
|
wolffd@0
|
269 % Loading training data --------------------------------------------------------------------------
|
wolffd@0
|
270 % ------------------------------------------------------------------------------------------------
|
wolffd@0
|
271 train_path=input('Insert the complete (with extension) path of the training data file:\n >> ','s');
|
wolffd@0
|
272 if isempty(train_path), error('You must specify a data set for training!'); end
|
wolffd@0
|
273 if ~isempty(findstr('.mat',train_path)),
|
wolffd@0
|
274 ap=load(train_path); app=fieldnames(ap); train_data=eval(['ap.', app{1,1}]);
|
wolffd@0
|
275 clear ap app;
|
wolffd@0
|
276 elseif ~isempty(findstr('.txt',train_path)),
|
wolffd@0
|
277 train_data=load(train_path, '-ascii');
|
wolffd@0
|
278 else
|
wolffd@0
|
279 error('Invalid data format: not a .mat or a .txt file')
|
wolffd@0
|
280 end
|
wolffd@0
|
281 if (size(train_data,2)~=cov_dim+res_dim)&(type==1),
|
wolffd@0
|
282 error(['Invalid data matrix size: ', num2str(size(train_data,2)), ' columns rather than ',...
|
wolffd@0
|
283 num2str(cov_dim+res_dim),'!']);
|
wolffd@0
|
284 elseif (size(train_data,2)~=cov_dim+1)&(type==2),
|
wolffd@0
|
285 error(['Invalid data matrix size: ', num2str(size(train_data,2)), ' columns rather than ',...
|
wolffd@0
|
286 num2str(cov_dim+1),'!']);
|
wolffd@0
|
287 elseif (~isempty(find(ismember(intersect([train_data(:,end)' 1:res_dim],...
|
wolffd@0
|
288 train_data(:,end)'),[1:res_dim])==0)))&(type==2),
|
wolffd@0
|
289 error('Invalid class label');
|
wolffd@0
|
290 end
|
wolffd@0
|
291 ntrain=size(train_data,1);
|
wolffd@0
|
292 train_d=train_data(:,1:cov_dim);
|
wolffd@0
|
293 if type==2,
|
wolffd@0
|
294 train_t=zeros(ntrain, res_dim);
|
wolffd@0
|
295 for m=1:res_dim,
|
wolffd@0
|
296 train_t((find(train_data(:,end)==m))',m)=1;
|
wolffd@0
|
297 end
|
wolffd@0
|
298 else
|
wolffd@0
|
299 train_t=train_data(:,cov_dim+1:end);
|
wolffd@0
|
300 end
|
wolffd@0
|
301 disp(' ')
|
wolffd@0
|
302 % ------------------------------------------------------------------------------------------------
|
wolffd@0
|
303 % Loading test data ------------------------------------------------------------------------------
|
wolffd@0
|
304 % ------------------------------------------------------------------------------------------------
|
wolffd@0
|
305 disp('(If you don''t want to specify a test-set press ''return'' only)');
|
wolffd@0
|
306 test_path=input('Insert the complete (with extension) path of the test data file:\n >> ','s');
|
wolffd@0
|
307 if ~isempty(test_path),
|
wolffd@0
|
308 if ~isempty(findstr('.mat',test_path)),
|
wolffd@0
|
309 ap=load(test_path); app=fieldnames(ap); test_data=eval(['ap.', app{1,1}]);
|
wolffd@0
|
310 clear ap app;
|
wolffd@0
|
311 elseif ~isempty(findstr('.txt',test_path)),
|
wolffd@0
|
312 test_data=load(test_path, '-ascii');
|
wolffd@0
|
313 else
|
wolffd@0
|
314 error('Invalid data format: not a .mat or a .txt file')
|
wolffd@0
|
315 end
|
wolffd@0
|
316 if (size(test_data,2)~=cov_dim)&(size(test_data,2)~=cov_dim+res_dim)&(type==1),
|
wolffd@0
|
317 error(['Invalid data matrix size: ', num2str(size(test_data,2)), ' columns rather than ',...
|
wolffd@0
|
318 num2str(cov_dim+res_dim), ' or ', num2str(cov_dim), '!']);
|
wolffd@0
|
319 elseif (size(test_data,2)~=cov_dim)&(size(test_data,2)~=cov_dim+1)&(type==2),
|
wolffd@0
|
320 error(['Invalid data matrix size: ', num2str(size(test_data,2)), ' columns rather than ',...
|
wolffd@0
|
321 num2str(cov_dim+1), ' or ', num2str(cov_dim), '!']);
|
wolffd@0
|
322 elseif (~isempty(find(ismember(intersect([test_data(:,end)' 1:res_dim],...
|
wolffd@0
|
323 test_data(:,end)'),[1:res_dim])==0)))&(type==2)&(size(test_data,2)==cov_dim+1),
|
wolffd@0
|
324 error('Invalid class label');
|
wolffd@0
|
325 end
|
wolffd@0
|
326 ntest=size(test_data,1);
|
wolffd@0
|
327 test_d=test_data(:,1:cov_dim);
|
wolffd@0
|
328 if (type==2)&(size(test_data,2)>cov_dim),
|
wolffd@0
|
329 test_t=zeros(ntest, res_dim);
|
wolffd@0
|
330 for m=1:res_dim,
|
wolffd@0
|
331 test_t((find(test_data(:,end)==m))',m)=1;
|
wolffd@0
|
332 end
|
wolffd@0
|
333 elseif (type==1)&(size(test_data,2)>cov_dim),
|
wolffd@0
|
334 test_t=test_data(:,cov_dim+1:end);
|
wolffd@0
|
335 end
|
wolffd@0
|
336 disp(' ');
|
wolffd@0
|
337 end
|
wolffd@0
|
338 else
|
wolffd@0
|
339 clc
|
wolffd@0
|
340 disp('----------------------------------------------------');
|
wolffd@0
|
341 disp(' Specify the Input ');
|
wolffd@0
|
342 disp('----------------------------------------------------');
|
wolffd@0
|
343 disp(' ')
|
wolffd@0
|
344 ntrain = input('Insert the number of examples in training (<500): ');
|
wolffd@0
|
345 if (isempty(ntrain)|(floor(ntrain)~=ntrain)|(ntrain<=0)|(ntrain>500)),
|
wolffd@0
|
346 error(['Invalid value: ', num2str(ntrain), ' is not a positive integer <500!']);
|
wolffd@0
|
347 end
|
wolffd@0
|
348 disp(' ')
|
wolffd@0
|
349 test_path='toy';
|
wolffd@0
|
350 ntest = input('Insert the number of examples in test (<500): ');
|
wolffd@0
|
351 if (isempty(ntest)|(floor(ntest)~=ntest)|(ntest<=0)|(ntest>500)),
|
wolffd@0
|
352 error(['Invalid value: ', num2str(ntest), ' is not a positive integer <500!']);
|
wolffd@0
|
353 end
|
wolffd@0
|
354
|
wolffd@0
|
355 if type==2,
|
wolffd@0
|
356 cov_dim=2;
|
wolffd@0
|
357 res_dim=3;
|
wolffd@0
|
358 seed = 42;
|
wolffd@0
|
359 [train_d, ntrain1, ntrain2, train_t]=gen_data(ntrain, seed);
|
wolffd@0
|
360 for m=1:ntrain
|
wolffd@0
|
361 q=[]; q = find(train_t(m,:)==1);
|
wolffd@0
|
362 train_data(m,:)=[train_d(m,:) q];
|
wolffd@0
|
363 end
|
wolffd@0
|
364 [test_d, ntest1, ntest2, test_t]=gen_data(ntest);
|
wolffd@0
|
365 for m=1:ntest
|
wolffd@0
|
366 q=[]; q = find(test_t(m,:)==1);
|
wolffd@0
|
367 test_data(m,:)=[test_d(m,:) q];
|
wolffd@0
|
368 end
|
wolffd@0
|
369 else
|
wolffd@0
|
370 cov_dim=1;
|
wolffd@0
|
371 res_dim=1;
|
wolffd@0
|
372 global HOME
|
wolffd@0
|
373 %%%%%WARNING!%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
wolffd@0
|
374 load([HOME '/examples/static/Misc/mixexp_data.txt'], '-ascii');
|
wolffd@0
|
375 %%%%%WARNING!%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
|
wolffd@0
|
376 train_data = mixexp_data(1:ntrain, :);
|
wolffd@0
|
377 train_d=train_data(:,1:cov_dim); train_t=train_data(:,cov_dim+1:end);
|
wolffd@0
|
378 test_data = mixexp_data(ntrain+1:ntrain+ntest, :);
|
wolffd@0
|
379 test_d=test_data(:,1:cov_dim);
|
wolffd@0
|
380 if size(test_data,2)>cov_dim,
|
wolffd@0
|
381 test_t=test_data(:,cov_dim+1:end);
|
wolffd@0
|
382 end
|
wolffd@0
|
383 end
|
wolffd@0
|
384 end
|
wolffd@0
|
385 % Set the nodes dimension-----------------------------------
|
wolffd@0
|
386 if num_glevel>0,
|
wolffd@0
|
387 nodes_info(2,2:num_glevel+1)=branch_fact;
|
wolffd@0
|
388 end
|
wolffd@0
|
389 nodes_info(2,1)=cov_dim; nodes_info(2,end)=res_dim;
|
wolffd@0
|
390 %-----------------------------------------------------------
|
wolffd@0
|
391 % Prepare the training data for the learning engine---------
|
wolffd@0
|
392 %-----------------------------------------------------------
|
wolffd@0
|
393 cases = cell(size(nodes_info,2), ntrain);
|
wolffd@0
|
394 for m=1:ntrain,
|
wolffd@0
|
395 cases{1,m}=train_data(m,1:cov_dim)';
|
wolffd@0
|
396 cases{end,m}=train_data(m,cov_dim+1:end)';
|
wolffd@0
|
397 end
|
wolffd@0
|
398 %-----------------------------------------------------------------------------------------------------
|
wolffd@0
|
399 [bnet onodes]=hme_topobuilder(nodes_info);
|
wolffd@0
|
400 engine = jtree_inf_engine(bnet, onodes);
|
wolffd@0
|
401 clc
|
wolffd@0
|
402 disp('---------------------------------------------------------------------');
|
wolffd@0
|
403 disp(' L E A R N I N G ');
|
wolffd@0
|
404 disp('---------------------------------------------------------------------');
|
wolffd@0
|
405 disp(' ')
|
wolffd@0
|
406 ll = 0;
|
wolffd@0
|
407 for l=1:ntrain
|
wolffd@0
|
408 scritta=['example number: ', int2str(l),'---------------------------------------------'];
|
wolffd@0
|
409 disp(scritta);
|
wolffd@0
|
410 ev = cases(:,l);
|
wolffd@0
|
411 [engine, loglik] = enter_evidence(engine, ev);
|
wolffd@0
|
412 ll = ll + loglik;
|
wolffd@0
|
413 end
|
wolffd@0
|
414 disp(' ')
|
wolffd@0
|
415 disp(['Log-likelihood before learning: ', num2str(ll)]);
|
wolffd@0
|
416 disp(' ')
|
wolffd@0
|
417 disp('(Press any key to continue)');
|
wolffd@0
|
418 pause
|
wolffd@0
|
419 %-----------------------------------------------------------
|
wolffd@0
|
420 clc
|
wolffd@0
|
421 disp('---------------------------------------------------------------------');
|
wolffd@0
|
422 disp(' L E A R N I N G ');
|
wolffd@0
|
423 disp('---------------------------------------------------------------------');
|
wolffd@0
|
424 disp(' ')
|
wolffd@0
|
425 max_em_iter=input('Insert the maximum number of the EM algorithm iterations: ');
|
wolffd@0
|
426 if (isempty(max_em_iter)|(floor(max_em_iter)~=max_em_iter)|(max_em_iter<=1)), % must be an integer greater than 1
|
wolffd@0
|
427 error(['Invalid value: ', num2str(max_em_iter), ' is not a positive integer >1!']); % report the offending value itself (was ntest)
|
wolffd@0
|
428 end
|
wolffd@0
|
429 disp(' ')
|
wolffd@0
|
430 disp(['Log-likelihood before learning: ', num2str(ll)]);
|
wolffd@0
|
431 disp(' ')
|
wolffd@0
|
432
|
wolffd@0
|
433 [bnet2, LL2] = learn_params_em(engine, cases, max_em_iter);
|
wolffd@0
|
434 disp(' ')
|
wolffd@0
|
435 fprintf('HME: loglik before learning %f, after %d iters %f\n', ll, length(LL2), LL2(end));
|
wolffd@0
|
436 disp(' ')
|
wolffd@0
|
437 disp('(Press any key to continue)');
|
wolffd@0
|
438 pause
|
wolffd@0
|
439 %-----------------------------------------------------------------------------------
|
wolffd@0
|
440 % Classification problem: plot data & decision boundaries if the input data size = 2
|
wolffd@0
|
441 % Regression problem: plot data & prediction if the input data size = 1
|
wolffd@0
|
442 %-----------------------------------------------------------------------------------
|
wolffd@0
|
443 if (type==2)&(nodes_info(2,1)==2)&(~isempty(test_path)),
|
wolffd@0
|
444 fh1=hme_class_plot(bnet2, nodes_info, train_data, test_data);
|
wolffd@0
|
445 disp(' ');
|
wolffd@0
|
446 disp('(See the figure)');
|
wolffd@0
|
447 elseif (type==2)&(nodes_info(2,1)==2)&(isempty(test_path)),
|
wolffd@0
|
448 fh1=hme_class_plot(bnet2, nodes_info, train_data);
|
wolffd@0
|
449 disp(' ');
|
wolffd@0
|
450 disp('(See the figure)');
|
wolffd@0
|
451 elseif (type==1)&(nodes_info(2,1)==1)&(~isempty(test_path)),
|
wolffd@0
|
452 fh1=hme_reg_plot(bnet2, nodes_info, train_data, test_data);
|
wolffd@0
|
453 disp(' ');
|
wolffd@0
|
454 disp('(See the figure)');
|
wolffd@0
|
455 elseif (type==1)&(nodes_info(2,1)==1)&(isempty(test_path)),
|
wolffd@0
|
456 fh1=hme_reg_plot(bnet2, nodes_info, train_data);
|
wolffd@0
|
457 disp(' ')
|
wolffd@0
|
458 disp('(See the figure)');
|
wolffd@0
|
459 end
|
wolffd@0
|
460 %-----------------------------------------------------------------------------------
|
wolffd@0
|
461 % Classification problem: plot confusion matrix
|
wolffd@0
|
462 %-----------------------------------------------------------------------------------
|
wolffd@0
|
463 if (type==2)
|
wolffd@0
|
464 ztrain=fhme(bnet2, nodes_info, train_d, size(train_d,1));
|
wolffd@0
|
465 [Htrain, trainRate]=confmat(ztrain, train_t); % CM on the training set
|
wolffd@0
|
466 fh2=figure('Name','Confusion matrix', 'MenuBar', 'none', 'NumberTitle', 'off');
|
wolffd@0
|
467 if (~isempty(test_path))&(size(test_data,2)>cov_dim),
|
wolffd@0
|
468 ztest=fhme(bnet2, nodes_info, test_d, size(test_d,1));
|
wolffd@0
|
469 [Htest, testRate]=confmat(ztest, test_t); % CM on the test set
|
wolffd@0
|
470 subplot(1,2,1);
|
wolffd@0
|
471 end
|
wolffd@0
|
472 plotmat(Htrain,'b','k',12)
|
wolffd@0
|
473 tick=[0.5:1:(0.5+nodes_info(2,end)-1)];
|
wolffd@0
|
474 set(gca,'XTick',tick)
|
wolffd@0
|
475 set(gca,'YTick',tick)
|
wolffd@0
|
476 grid('off')
|
wolffd@0
|
477 ylabel('True')
|
wolffd@0
|
478 xlabel('Prediction')
|
wolffd@0
|
479 title(['Confusion Matrix: training set (' num2str(trainRate(1)) '%)'])
|
wolffd@0
|
480 if (~isempty(test_path))&(size(test_data,2)>cov_dim),
|
wolffd@0
|
481 subplot(1,2,2)
|
wolffd@0
|
482 plotmat(Htest,'b','k',12)
|
wolffd@0
|
483 set(gca,'XTick',tick)
|
wolffd@0
|
484 set(gca,'YTick',tick)
|
wolffd@0
|
485 grid('off')
|
wolffd@0
|
486 ylabel('True')
|
wolffd@0
|
487 xlabel('Prediction')
|
wolffd@0
|
488 title(['Confusion Matrix: test set (' num2str(testRate(1)) '%)'])
|
wolffd@0
|
489 end
|
wolffd@0
|
490 disp(' ')
|
wolffd@0
|
491 disp('(Press any key to continue)');
|
wolffd@0
|
492 pause
|
wolffd@0
|
493 end
|
wolffd@0
|
494 %-----------------------------------------------------------------------------------
|
wolffd@0
|
495 % Regression & Classification problem: calculate the predictions & plot the LL trace
|
wolffd@0
|
496 %-----------------------------------------------------------------------------------
|
wolffd@0
|
497 train_result=fhme(bnet2,nodes_info,train_d,size(train_d,1));
|
wolffd@0
|
498 if ~isempty(test_path),
|
wolffd@0
|
499 test_result=fhme(bnet2,nodes_info,test_d,size(test_d,1));
|
wolffd@0
|
500 end
|
wolffd@0
|
501 fh3=figure('Name','Log-likelihood trace', 'MenuBar', 'none', 'NumberTitle', 'off')
|
wolffd@0
|
502 plot(LL2,'-ro',...
|
wolffd@0
|
503 'MarkerEdgeColor','k',...
|
wolffd@0
|
504 'MarkerFaceColor',[1 1 0],...
|
wolffd@0
|
505 'MarkerSize',4)
|
wolffd@0
|
506 title('Log-likelihood trace')
|
wolffd@0
|
507 %-----------------------------------------------------------------------------------
|
wolffd@0
|
508 % Regression & Classification problem: save the predictions
|
wolffd@0
|
509 %-----------------------------------------------------------------------------------
|
wolffd@0
|
510 clc
|
wolffd@0
|
511 disp('------------------------------------------------------------------');
|
wolffd@0
|
512 disp(' Save the results ');
|
wolffd@0
|
513 disp('------------------------------------------------------------------');
|
wolffd@0
|
514 disp(' ')
|
wolffd@0
|
515 %-----------------------------------------------------------------------------------
|
wolffd@0
|
516 save_quest_m=input('Do you want to save the HME model (Y/N)? [Y default]: ', 's');
|
wolffd@0
|
517 if isempty(save_quest_m),
|
wolffd@0
|
518 save_quest_m='Y';
|
wolffd@0
|
519 end
|
wolffd@0
|
520 if ~any(strcmpi(save_quest_m, {'Y', 'N'})), error('Invalid input'); end % whole-string match; the old ~findstr(...) test could never fire
|
wolffd@0
|
521 if strcmpi(save_quest_m, 'Y'), % scalar string comparison instead of element-wise ==
|
wolffd@0
|
522 disp(' ');
|
wolffd@0
|
523 m_save=input('Insert the complete path for save the HME model (.mat):\n >> ', 's');
|
wolffd@0
|
524 if isempty(m_save), error('You must specify a path!'); end
|
wolffd@0
|
525 save(m_save, 'bnet2'); % persist the learned network only
|
wolffd@0
|
526 end
|
wolffd@0
|
527 %-----------------------------------------------------------------------------------
|
wolffd@0
|
528 disp(' ')
|
wolffd@0
|
529 save_quest=input('Do you want to save the HME predictions (Y/N)? [Y default]: ', 's');
|
wolffd@0
|
530 disp(' ')
|
wolffd@0
|
531 if isempty(save_quest),
|
wolffd@0
|
532 save_quest='Y';
|
wolffd@0
|
533 end
|
wolffd@0
|
534 if ~any(strcmpi(save_quest, {'Y', 'N'})), error('Invalid input'); end % whole-string match; the old ~findstr(...) test could never fire
|
wolffd@0
|
535 if strcmpi(save_quest, 'Y'), % scalar string comparison instead of element-wise ==
|
wolffd@0
|
536 tr_save=input('Insert the complete path for save the training data prediction (.mat):\n >> ', 's');
|
wolffd@0
|
537 if isempty(tr_save), error('You must specify a path!'); end
|
wolffd@0
|
538 save(tr_save, 'train_result');
|
wolffd@0
|
539 if ~isempty(test_path), % test predictions exist only when a test set was supplied
|
wolffd@0
|
540 disp(' ')
|
wolffd@0
|
541 te_save=input('Insert the complete path for save the test data prediction (.mat):\n >> ', 's');
|
wolffd@0
|
542 if isempty(te_save), error('You must specify a path!'); end
|
wolffd@0
|
543 save(te_save, 'test_result');
|
wolffd@0
|
544 end
|
wolffd@0
|
545 end
|
wolffd@0
|
546 clc
|
wolffd@0
|
547 disp('----------------------------------------------------');
|
wolffd@0
|
548 disp(' B Y E ! ');
|
wolffd@0
|
549 disp('----------------------------------------------------');
|
wolffd@0
|
550 pause(2)
|
wolffd@0
|
551 %clear
|
wolffd@0
|
552 clc
|