annotate toolboxes/FullBNT-1.0.7/bnt/examples/static/HME/hmemenu.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
wolffd@0 1 % dataset -> (1=>user data) or (2=>toy example)
wolffd@0 2 % type -> (1=> Regression model) or (2=>Classification model)
wolffd@0 3 % num_glevel -> number of hidden nodes in the net (gating levels)
wolffd@0 4 % num_exp -> number of experts in the net
wolffd@0 5 % branch_fact -> dimension of the hidden nodes in the net
wolffd@0 6 % cov_dim -> root node dimension
wolffd@0 7 % res_dim -> output node dimension
% nodes_info -> 4 x (num_glevel+2) matrix that contains all the info about the nodes:
% nodes_info(1,:) = node type: (0=>gaussian) or (1=>softmax) or (2=>mlp)
% nodes_info(2,:) = node size: [cov_dim, branch_fact (once per gating level), res_dim]
% nodes_info(3,:) = number of hidden units (mlp nodes only)
% nodes_info(4,:) = optimizer iteration number (softmax & mlp CPDs), or
%                   covariance type (gaussian CPD):
%                   (1=>Full) or (2=>Diagonal) or (3=>Full&Tied) or (4=>Diagonal&Tied)
% fh1 -> Figure: data & decision boundaries (classification) or data & prediction
%        (regression); fh2 -> confusion matrix; fh3 -> LL trace
% test_data -> test data matrix
% train_data -> training data matrix
% ntrain -> size(train_data,1) (number of training examples)
% ntest -> size(test_data,1) (number of test examples)
wolffd@0 20 % cases -> (cell array) training data formatted for the learning engine
wolffd@0 21 % bnet -> bayesian net before learning
wolffd@0 22 % bnet2 -> bayesian net after learning
wolffd@0 23 % ll -> log-likelihood before learning
wolffd@0 24 % LL2 -> log-likelihood trace
wolffd@0 25 % onodes -> obs nodes in bnet & bnet2
% max_em_iter -> maximum number of iterations of the EM algorithm
% train_result -> prediction on the training set (test_result holds the same for the test set)
%
% IMPORTANT: CHECK the loading paths marked %%%%%WARNING! (the lookup of
% HMEforMatlab.jpg and the load of mixexp_data.txt relative to the global HOME)
wolffd@0 30 % ----------------------------------------------------------------------------------------------------
wolffd@0 31 % -> pierpaolo_b@hotmail.com or -> pampo@interfree.it
wolffd@0 32 % ----------------------------------------------------------------------------------------------------
wolffd@0 33
% This menu is disabled: execution stops here because the script is not
% compatible with the current BNT release.
error('this no longer works with the latest version of BNT')

clear all;
clc;
% Title banner, description of the supported model families and the
% bibliographic references, printed one line each.
intro_text = {
    '---------------------------------------------------'
    ' Hierarchical Mixtures of Experts models builder '
    '---------------------------------------------------'
    ' '
    ' Using this script you can build both an HME model'
    'as in [Wat94] and [Jor94] i.e. with ''softmax'' gating'
    'nodes and ''gaussian'' ( for regression ) or ''softmax'''
    '( for classification ) expert node, and its variants'
    'called ''gated nets'' where we use ''mlp'' models in'
    'place of a number of ''softmax'' ones [Mor98], [Wei95].'
    ' You can decide to train and test the model on your'
    'datasets or to evaluate its performance on a toy'
    'example.'
    ' '
    'Reference'
    '[Mor98] P. Moerland (1998):'
    ' Localized mixtures of experts. (http://www.idiap.ch/~perry/)'
    '[Jor94] M.I. Jordan, R.A. Jacobs (1994):'
    ' HME and the EM algorithm. (http://www.cs.berkeley.edu/~jordan/)'
    '[Wat94] S.R. Waterhouse, A.J. Robinson (1994):'
    ' Classification using HME. (http://www.oigeeza.com/steve/)'
    '[Wei95] A.S. Weigend, M. Mangeas (1995):'
    ' Nonlinear gated experts for time series.'
    ' '
    };
fprintf('%s\n', intro_text{:});
clear intro_text
wolffd@0 62
% Preview of the HME architecture picture. The guard is constant-false, so
% this section never runs; it is kept for reference only.
if false
    disp('(See the figure)')
    pause(5);
    %%%%%WARNING!%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % The picture is located through the MATLAB search path.
    im_path = which('HMEforMatlab.jpg');
    fig = imread(im_path, 'jpg');
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    figure('Units', 'pixels', 'MenuBar', 'none', 'NumberTitle', 'off', 'Name', 'HME model');
    image(fig);
    axis image;
    axis off;
    clear fig;
    % Let the image fill the whole figure window.
    set(gca, 'Position', [0 0 1 1])
    disp('(Press any key to continue)')
    pause
end
wolffd@0 79
% Model family, stored in 'type': 1 => regression, 2 => classification.
clc
fprintf('%s\n', ...
    '---------------------------------------------------', ...
    ' Specify the Architecture ', ...
    '---------------------------------------------------', ...
    ' ', ...
    'What kind of model do you need?', ...
    ' ', ...
    '1) Regression ', ...
    '2) Classification', ...
    ' ');
type = input('1 or 2?: ');
if isempty(type) | ~ismember(type, [1 2])
    error('Invalid value');
end
% Depth of the hierarchy: number of gating levels (0 yields a plain GLM).
clc
fprintf('%s\n', ...
    '----------------------------------------------------', ...
    ' Specify the Architecture ', ...
    '----------------------------------------------------', ...
    ' ', ...
    'Now you have to set the number of experts and gating', ...
    'levels in the net. This script builds only balanced', ...
    'hierarchy with the same branching factor (>1)at each', ...
    '(gating) level. So remember that: ', ...
    ' ', ...
    ' num_exp = branch_fact^num_glevel ', ...
    ' ', ...
    'with branch_fact >=2.', ...
    'You can also set to zeros the number of gating level', ...
    'in order to obtain a classical GLM model. ', ...
    ' ', ...
    '----------------------------------------------------', ...
    ' ');
num_glevel = input('Insert the number of gating levels {0,...,20}: ');
if isempty(num_glevel) | ~ismember(num_glevel, 0:20)
    error('Invalid value');
end
% One column per node: root (covariates), num_glevel gating nodes, expert.
nodes_info = zeros(4, num_glevel+2);
% Configure every gating network and derive the branching factor.
% Answers are read into a temporary and validated BEFORE being stored:
% assigning an empty (or non-scalar) answer directly into nodes_info(r,i)
% raises a dimension-mismatch error, so the former isempty() checks could
% never fire.
% is_choice:  numeric scalar that belongs to the allowed option list.
% is_pos_int: finite, positive, integer-valued numeric scalar.
is_choice = @(v, opts) isnumeric(v) && isscalar(v) && ismember(v, opts);
is_pos_int = @(v) isnumeric(v) && isscalar(v) && isfinite(v) && v>0 && floor(v)==v;
if num_glevel>0, %------------------------------------------------------------------------------------
    % Gating network k is stored in column k+1 of nodes_info (column 1 is
    % the root/covariate node).
    for i=2:num_glevel+1,
        clc
        disp('----------------------------------------------------');
        disp(' Specify the Architecture ');
        disp('----------------------------------------------------');
        disp(' ')
        disp(['-> Gating network ', num2str(i-1), ' is a: '])
        disp(' ')
        disp(' 1) Softmax model');
        disp(' 2) Two layer perceptron model')
        disp(' ')
        answer=input('1 or 2?: ');
        if ~is_choice(answer, [1 2]), error('Invalid value'); end
        nodes_info(1,i)=answer;
        disp(' ')
        if answer==2,
            % MLP gating node: it also needs the size of its hidden layer.
            answer=input('Insert the number of units in the hidden layer: ');
            if ~is_pos_int(answer),
                error(['Invalid value: ', num2str(answer), ' is not a positive integer!']);
            end
            nodes_info(3,i)=answer;
            disp(' ')
        end
        answer=input('Insert the optimizer iteration number: ');
        if ~is_pos_int(answer),
            error(['Invalid value: ', num2str(answer), ' is not a positive integer!']);
        end
        nodes_info(4,i)=answer;
    end
    clc
    disp('---------------------------------------------------------');
    disp(' Specify the Architecture ');
    disp('---------------------------------------------------------');
    disp(' ')
    disp('Now you have to set the number of experts in the network');
    disp('The value will be adjusted in order to obtain a hierarchy');
    disp('as said above.')
    disp(' ');
    num_exp=input(['Insert the approximative number of experts (>=', num2str(2^num_glevel), '): ']);
    % NaN/Inf are rejected too: Inf would make the search loop below run
    % (practically) forever, NaN would yield a branching factor of 1.
    if isempty(num_exp) || ~isnumeric(num_exp) || ~isscalar(num_exp) || ~isfinite(num_exp) ...
            || num_exp<2^num_glevel,
        error('Invalid value');
    end
    % Choose branch_fact so that branch_fact^num_glevel is as close as
    % possible to num_exp: app1 is the smallest power base^num_glevel that is
    % >= num_exp, app2 the next smaller power (allowed only while its base
    % is still >= 2). Ties favour the larger hierarchy.
    app1=0; base=2;
    while app1<num_exp,
        app1=base^num_glevel;
        base=base+1;
    end
    app2=(base-2)^num_glevel;
    branch_fact=base-1;
    if app2>=(2^num_glevel) && abs(app2-num_exp)<abs(app1-num_exp),
        branch_fact=base-2;
    end
    clear app1 app2 base;
    disp(' ')
    disp(['The effective number of experts in the net is: ', num2str(branch_fact^num_glevel), '.'])
    disp(' ');
else
    clc
    disp('---------------------------------------------------------');
    disp(' Specify the Architecture (GLM model) ');
    disp('---------------------------------------------------------');
    disp(' ')
end % END of: if num_glevel>0-------------------------------------------------------------------------
clear answer is_choice is_pos_int;
wolffd@0 174
% Configure the expert (output) node, stored in the last column of nodes_info.
% Answers are read into a temporary and validated BEFORE being stored:
% assigning an empty answer directly into nodes_info(r,end) raises a
% dimension-mismatch error, so the former isempty() checks could never fire.
% is_choice:  numeric scalar that belongs to the allowed option list.
% is_pos_int: finite, positive, integer-valued numeric scalar.
is_choice = @(v, opts) isnumeric(v) && isscalar(v) && ismember(v, opts);
is_pos_int = @(v) isnumeric(v) && isscalar(v) && isfinite(v) && v>0 && floor(v)==v;
if type==2,
    % Classification: the expert is a softmax or an MLP node.
    disp('-> Expert node is a: ')
    disp(' ')
    disp(' 1) Softmax model');
    disp(' 2) Two layer perceptron model')
    disp(' ')
    answer=input('1 or 2?: ');
    if ~is_choice(answer, [1 2]),
        error('Invalid value');
    end
    nodes_info(1,end)=answer;
    disp(' ')
    if answer==2,
        answer=input('Insert the number of units in the hidden layer: ');
        if ~is_pos_int(answer),
            error(['Invalid value: ', num2str(answer), ' is not a positive integer!']);
        end
        nodes_info(3,end)=answer;
        disp(' ')
    end
    answer=input('Insert the optimizer iteration number: ');
    if ~is_pos_int(answer),
        error(['Invalid value: ', num2str(answer), ' is not a positive integer!']);
    end
    nodes_info(4,end)=answer;
elseif type==1,
    % Regression: the expert keeps node type 0 (gaussian); only its
    % covariance structure is asked for (stored in row 4).
    disp('What kind of covariance matrix structure do you want?')
    disp(' ')
    disp(' 1) Full');
    disp(' 2) Diagonal')
    disp(' 3) Full & Tied');
    disp(' 4) Diagonal & Tied')

    disp(' ')
    answer=input('1, 2, 3 or 4?: ');
    if ~is_choice(answer, [1 2 3 4]),
        error('Invalid value');
    end
    nodes_info(4,end)=answer;
end
clear answer is_choice is_pos_int;
% Data source, stored in 'dataset': 1 => the user's own files, 2 => toy example.
clc
fprintf('%s\n', ...
    '----------------------------------------------------', ...
    ' Specify the Input ', ...
    '----------------------------------------------------', ...
    ' ', ...
    'Do you want to...', ...
    ' ', ...
    '1) ...use your own dataset?', ...
    '2) ...apply the model on a toy example?', ...
    ' ');
dataset = input('1 or 2?: ');
if isempty(dataset) | ~ismember(dataset, [1 2])
    error('Invalid value');
end
% Acquire the training (and optional test) data: either from the user's
% .mat/.txt files (dataset==1) or from a built-in toy problem (dataset==2).
% Defines cov_dim, res_dim, train_data/train_d/train_t, ntrain, test_path
% and, when available, test_data/test_d/test_t and ntest.
% is_pos_int: finite, positive, integer-valued numeric scalar (an empty
% answer from input() is rejected because it is not a scalar).
is_pos_int = @(v) isnumeric(v) && isscalar(v) && isfinite(v) && v>0 && floor(v)==v;
if dataset==1,
    if type==1,
        clc
        disp('-------------------------------------------------------');
        disp(' Specify the Input - Regression problem ');
        disp('-------------------------------------------------------');
        disp(' ')
        disp('Be sure that each row of your data matrix is an example');
        disp('with the covariate values that precede the respond ones')
        disp(' ')
        disp('-------------------------------------------------------');
        disp(' ')
        cov_dim=input('Insert the covariate space dimension: ');
        if ~is_pos_int(cov_dim),
            error(['Invalid value: ', num2str(cov_dim), ' is not a positive integer!']);
        end
        disp(' ')
        res_dim=input('Insert the dimension of the respond variable: ');
        if ~is_pos_int(res_dim),
            error(['Invalid value: ', num2str(res_dim), ' is not a positive integer!']);
        end
        disp(' ');
    elseif type==2
        clc
        disp('-------------------------------------------------------');
        disp(' Specify the Input - Classification problem ');
        disp('-------------------------------------------------------');
        disp(' ')
        disp('Be sure that each row of your data matrix is an example');
        disp('with the covariate values that precede the class labels');
        disp('(integer value >=1). ');
        disp(' ')
        disp('-------------------------------------------------------');
        disp(' ')
        cov_dim=input('Insert the covariate space dimension: ');
        if ~is_pos_int(cov_dim),
            error(['Invalid value: ', num2str(cov_dim), ' is not a positive integer!']);
        end
        disp(' ')
        res_dim=input('Insert the number of classes: ');
        if ~is_pos_int(res_dim),
            error(['Invalid value: ', num2str(res_dim), ' is not a positive integer!']);
        end
        disp(' ')
    end
    % ------------------------------------------------------------------------------------------------
    % Loading training data --------------------------------------------------------------------------
    % ------------------------------------------------------------------------------------------------
    train_path=input('Insert the complete (with extension) path of the training data file:\n >> ','s');
    if isempty(train_path), error('You must specify a data set for training!'); end
    % Pick the reader from the real file extension; the former findstr()
    % test also matched '.mat'/'.txt' appearing anywhere inside the path.
    [p_dir, p_name, p_ext]=fileparts(train_path);
    if strcmpi(p_ext,'.mat'),
        % The first variable stored in the .mat file is used as the data matrix.
        ap=load(train_path); app=fieldnames(ap);
        if isempty(app), error(['No variable found in ', train_path]); end
        train_data=ap.(app{1});
        clear ap app;
    elseif strcmpi(p_ext,'.txt'),
        train_data=load(train_path, '-ascii');
    else
        error('Invalid data format: not a .mat or a .txt file')
    end
    if isempty(train_data), error('The training data set is empty!'); end
    if type==1 && size(train_data,2)~=cov_dim+res_dim,
        error(['Invalid data matrix size: ', num2str(size(train_data,2)), ' columns rather than ',...
            num2str(cov_dim+res_dim),'!']);
    elseif type==2 && size(train_data,2)~=cov_dim+1,
        error(['Invalid data matrix size: ', num2str(size(train_data,2)), ' columns rather than ',...
            num2str(cov_dim+1),'!']);
    elseif type==2 && any(~ismember(train_data(:,end), 1:res_dim)),
        % Every class label (last column) must be one of 1..res_dim.
        error('Invalid class label');
    end
    ntrain=size(train_data,1);
    train_d=train_data(:,1:cov_dim);
    if type==2,
        % One-hot encode the class labels: column m is 1 for class m.
        train_t=zeros(ntrain, res_dim);
        for m=1:res_dim,
            train_t(train_data(:,end)==m, m)=1;
        end
    else
        train_t=train_data(:,cov_dim+1:end);
    end
    disp(' ')
    % ------------------------------------------------------------------------------------------------
    % Loading test data (optional; targets may be omitted) -------------------------------------------
    % ------------------------------------------------------------------------------------------------
    disp('(If you don''t want to specify a test-set press ''return'' only)');
    test_path=input('Insert the complete (with extension) path of the test data file:\n >> ','s');
    if ~isempty(test_path),
        [p_dir, p_name, p_ext]=fileparts(test_path);
        if strcmpi(p_ext,'.mat'),
            ap=load(test_path); app=fieldnames(ap);
            if isempty(app), error(['No variable found in ', test_path]); end
            test_data=ap.(app{1});
            clear ap app;
        elseif strcmpi(p_ext,'.txt'),
            test_data=load(test_path, '-ascii');
        else
            error('Invalid data format: not a .mat or a .txt file')
        end
        if isempty(test_data), error('The test data set is empty!'); end
        if type==1 && size(test_data,2)~=cov_dim && size(test_data,2)~=cov_dim+res_dim,
            error(['Invalid data matrix size: ', num2str(size(test_data,2)), ' columns rather than ',...
                num2str(cov_dim+res_dim), ' or ', num2str(cov_dim), '!']);
        elseif type==2 && size(test_data,2)~=cov_dim && size(test_data,2)~=cov_dim+1,
            error(['Invalid data matrix size: ', num2str(size(test_data,2)), ' columns rather than ',...
                num2str(cov_dim+1), ' or ', num2str(cov_dim), '!']);
        elseif type==2 && size(test_data,2)==cov_dim+1 && any(~ismember(test_data(:,end), 1:res_dim)),
            error('Invalid class label');
        end
        ntest=size(test_data,1);
        test_d=test_data(:,1:cov_dim);
        % Targets exist only when the file has more than cov_dim columns.
        if type==2 && size(test_data,2)>cov_dim,
            test_t=zeros(ntest, res_dim);
            for m=1:res_dim,
                test_t(test_data(:,end)==m, m)=1;
            end
        elseif type==1 && size(test_data,2)>cov_dim,
            test_t=test_data(:,cov_dim+1:end);
        end
        disp(' ');
    end
else
    clc
    disp('----------------------------------------------------');
    disp(' Specify the Input ');
    disp('----------------------------------------------------');
    disp(' ')
    % The prompts say "<=500": the limit check has always accepted 500 itself.
    ntrain = input('Insert the number of examples in training (<=500): ');
    if ~is_pos_int(ntrain) || ntrain>500,
        error(['Invalid value: ', num2str(ntrain), ' is not a positive integer <=500!']);
    end
    disp(' ')
    test_path='toy';
    ntest = input('Insert the number of examples in test (<=500): ');
    if ~is_pos_int(ntest) || ntest>500,
        error(['Invalid value: ', num2str(ntest), ' is not a positive integer <=500!']);
    end

    if type==2,
        % Toy classification problem: 2 covariates, 3 classes, from gen_data.
        cov_dim=2;
        res_dim=3;
        seed = 42;
        [train_d, ntrain1, ntrain2, train_t]=gen_data(ntrain, seed);
        % Rebuild a [covariates label] matrix from the targets (assumed
        % one-hot, as the find(...==1) below requires).
        for m=1:ntrain
            q = find(train_t(m,:)==1);
            train_data(m,:)=[train_d(m,:) q];
        end
        % NOTE(review): no seed is passed for the test set, presumably so it
        % differs from the training set; confirm against gen_data.
        [test_d, ntest1, ntest2, test_t]=gen_data(ntest);
        for m=1:ntest
            q = find(test_t(m,:)==1);
            test_data(m,:)=[test_d(m,:) q];
        end
    else
        % Toy regression problem: 1 covariate, 1 response, read from
        % mixexp_data.txt. HOME presumably points at the BNT root directory.
        cov_dim=1;
        res_dim=1;
        global HOME
        if isempty(HOME),
            error('The global variable HOME must be set to the BNT root directory.');
        end
        %%%%%WARNING!%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        mixexp_data = load([HOME '/examples/static/Misc/mixexp_data.txt'], '-ascii');
        %%%%%WARNING!%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        if ntrain+ntest>size(mixexp_data,1),
            error(['The toy data set has only ', num2str(size(mixexp_data,1)), ...
                ' examples: reduce the training and/or test size.']);
        end
        % The first ntrain rows train the model, the following ntest rows test it.
        train_data = mixexp_data(1:ntrain, :);
        train_d=train_data(:,1:cov_dim); train_t=train_data(:,cov_dim+1:end);
        test_data = mixexp_data(ntrain+1:ntrain+ntest, :);
        test_d=test_data(:,1:cov_dim);
        if size(test_data,2)>cov_dim,
            test_t=test_data(:,cov_dim+1:end);
        end
    end
end
clear p_dir p_name p_ext is_pos_int;
% Set the nodes dimension-----------------------------------
% Root node: cov_dim covariates; each gating node: branch_fact states;
% output node: res_dim (responses or classes).
nodes_info(2,1)=cov_dim;
if num_glevel>0,
    nodes_info(2,2:num_glevel+1)=branch_fact;
end
nodes_info(2,end)=res_dim;
%-----------------------------------------------------------
% Prepare the training data for the learning engine---------
% One column of 'cases' per training example: the first row holds the
% covariate column vector, the last row the target column vector; the
% rows of the gating (hidden) nodes are left empty.
%-----------------------------------------------------------
cases = cell(size(nodes_info,2), ntrain);
cases(1,:)   = num2cell(train_data(:,1:cov_dim)', 1);
cases(end,:) = num2cell(train_data(:,cov_dim+1:end)', 1);
%-----------------------------------------------------------------------------------------------------
% Build the HME topology, attach a junction-tree inference engine and sum the
% log-likelihood of all training cases under the initial parameters.
[bnet onodes]=hme_topobuilder(nodes_info);
engine = jtree_inf_engine(bnet, onodes);
clc
fprintf('%s\n', ...
    '---------------------------------------------------------------------', ...
    ' L E A R N I N G ', ...
    '---------------------------------------------------------------------', ...
    ' ');
ll = 0;
for case_no = 1:ntrain
    disp(['example number: ', int2str(case_no), '---------------------------------------------']);
    % enter_evidence hands back the updated engine together with the case log-likelihood.
    [engine, case_ll] = enter_evidence(engine, cases(:,case_no));
    ll = ll + case_ll;
end
clear case_no case_ll
fprintf('%s\n', ' ', ['Log-likelihood before learning: ', num2str(ll)], ' ', '(Press any key to continue)');
pause
%-----------------------------------------------------------
% Ask for the EM iteration budget and learn the parameters from 'cases'.
clc
disp('---------------------------------------------------------------------');
disp(' L E A R N I N G ');
disp('---------------------------------------------------------------------');
disp(' ')
max_em_iter=input('Insert the maximum number of the EM algorithm iterations: ');
if isempty(max_em_iter) || ~isnumeric(max_em_iter) || ~isscalar(max_em_iter) || ...
        floor(max_em_iter)~=max_em_iter || max_em_iter<=1,
    % Report the value the user actually typed (this message used to print ntest).
    error(['Invalid value: ', num2str(max_em_iter), ' is not a positive integer >1!']);
end
disp(' ')
disp(['Log-likelihood before learning: ', num2str(ll)]);
disp(' ')

% bnet2 is the learned network, LL2 the log-likelihood trace of the EM run.
[bnet2, LL2] = learn_params_em(engine, cases, max_em_iter);
disp(' ')
fprintf('HME: loglik before learning %f, after %d iters %f\n', ll, length(LL2), LL2(end));
disp(' ')
disp('(Press any key to continue)');
pause
%-----------------------------------------------------------------------------------
% Plot the fitted model whenever the input space can be drawn:
%  - classification with 2 covariates -> data & decision boundaries
%  - regression with 1 covariate      -> data & prediction
% The test set is passed along whenever test_path is non-empty.
%-----------------------------------------------------------------------------------
if (type==2)&(nodes_info(2,1)==2)
    if isempty(test_path)
        fh1=hme_class_plot(bnet2, nodes_info, train_data);
    else
        fh1=hme_class_plot(bnet2, nodes_info, train_data, test_data);
    end
    disp(' ');
    disp('(See the figure)');
elseif (type==1)&(nodes_info(2,1)==1)
    if isempty(test_path)
        fh1=hme_reg_plot(bnet2, nodes_info, train_data);
    else
        fh1=hme_reg_plot(bnet2, nodes_info, train_data, test_data);
    end
    disp(' ');
    disp('(See the figure)');
end
%-----------------------------------------------------------------------------------
% Classification only: confusion matrix on the training set and, when the test
% file carries labels, on the test set as well (two panels side by side).
%-----------------------------------------------------------------------------------
if (type==2)
    ztrain=fhme(bnet2, nodes_info, train_d, size(train_d,1));
    [Htrain, trainRate]=confmat(ztrain, train_t); % CM on the training set
    fh2=figure('Name','Confusion matrix', 'MenuBar', 'none', 'NumberTitle', 'off');
    % && short-circuits, so test_data is never touched when no test set exists.
    show_test = ~isempty(test_path) && size(test_data,2)>cov_dim;
    if show_test
        ztest=fhme(bnet2, nodes_info, test_d, size(test_d,1));
        [Htest, testRate]=confmat(ztest, test_t); % CM on the test set
        panels = {Htrain, trainRate, 'Confusion Matrix: training set ('; ...
                  Htest,  testRate,  'Confusion Matrix: test set ('};
    else
        panels = {Htrain, trainRate, 'Confusion Matrix: training set ('};
    end
    % Tick marks centred on the matrix cells (one per class).
    tick=[0.5:1:(0.5+nodes_info(2,end)-1)];
    for p=1:size(panels,1)
        if show_test, subplot(1,2,p); end
        plotmat(panels{p,1},'b','k',12)
        set(gca,'XTick',tick)
        set(gca,'YTick',tick)
        grid('off')
        ylabel('True')
        xlabel('Prediction')
        title([panels{p,3} num2str(panels{p,2}(1)) '%)'])
    end
    clear show_test panels p
    disp(' ')
    disp('(Press any key to continue)');
    pause
end
%-----------------------------------------------------------------------------------
% Regression & Classification problem: calculate the predictions & plot the LL trace
%-----------------------------------------------------------------------------------
train_result=fhme(bnet2,nodes_info,train_d,size(train_d,1));
if ~isempty(test_path),
    test_result=fhme(bnet2,nodes_info,test_d,size(test_d,1));
end
% Terminated with ';' so the figure handle is no longer echoed to the console.
fh3=figure('Name','Log-likelihood trace', 'MenuBar', 'none', 'NumberTitle', 'off');
plot(LL2,'-ro',...
    'MarkerEdgeColor','k',...
    'MarkerFaceColor',[1 1 0],...
    'MarkerSize',4)
title('Log-likelihood trace')
% One LL2 entry per EM iteration.
xlabel('EM iteration')
ylabel('Log-likelihood')
%-----------------------------------------------------------------------------------
% Regression & Classification problem: save the model and the predictions
%-----------------------------------------------------------------------------------
clc
disp('------------------------------------------------------------------');
disp(' Save the results ');
disp('------------------------------------------------------------------');
disp(' ')
%-----------------------------------------------------------------------------------
% Optionally save the learned network bnet2 to a .mat file.
save_quest_m=input('Do you want to save the HME model (Y/N)? [Y default]: ', 's');
% Normalise the answer (surrounding blanks, lower case); empty means 'Y'.
save_quest_m=upper(strtrim(save_quest_m));
if isempty(save_quest_m),
    save_quest_m='Y';
end
% Accept exactly Y or N. The former test "~findstr(answer, 'YN')" let any
% unexpected answer through silently: findstr returns [] when nothing
% matches, ~[] is [], and "if []" is false, so no error was ever raised.
if ~any(strcmp(save_quest_m, {'Y', 'N'})), error('Invalid input'); end
if strcmp(save_quest_m, 'Y'),
    disp(' ');
    m_save=input('Insert the complete path for save the HME model (.mat):\n >> ', 's');
    if isempty(m_save), error('You must specify a path!'); end
    save(m_save, 'bnet2');
end
%-----------------------------------------------------------------------------------
% Optionally save train_result (and test_result, when a test set exists).
disp(' ')
save_quest=input('Do you want to save the HME predictions (Y/N)? [Y default]: ', 's');
disp(' ')
% Normalise the answer (surrounding blanks, lower case); empty means 'Y'.
save_quest=upper(strtrim(save_quest));
if isempty(save_quest),
    save_quest='Y';
end
% Accept exactly Y or N. The former test "~findstr(answer, 'YN')" let any
% unexpected answer through silently: findstr returns [] when nothing
% matches, ~[] is [], and "if []" is false, so no error was ever raised.
if ~any(strcmp(save_quest, {'Y', 'N'})), error('Invalid input'); end
if strcmp(save_quest, 'Y'),
    tr_save=input('Insert the complete path for save the training data prediction (.mat):\n >> ', 's');
    if isempty(tr_save), error('You must specify a path!'); end
    save(tr_save, 'train_result');
    if ~isempty(test_path),
        disp(' ')
        te_save=input('Insert the complete path for save the test data prediction (.mat):\n >> ', 's');
        if isempty(te_save), error('You must specify a path!'); end
        save(te_save, 'test_result');
    end
end
% Farewell banner, shown for two seconds before the console is cleared.
clc
fprintf('%s\n', ...
    '----------------------------------------------------', ...
    ' B Y E ! ', ...
    '----------------------------------------------------');
pause(2)
%clear
clc