toolboxes/FullBNT-1.0.7/bnt/CPDs/@tree_CPD/learn_params.m @ 0:e9a9cd732c1e (tip)
author: wolffd
date:   Tue, 10 Feb 2015 15:05:51 +0000
commit message: first hg version after svn
function CPD = learn_params(CPD, fam, data, ns, cnodes, varargin)
% LEARN_PARAMS Construct a classification/regression tree given complete data
% CPD = learn_params(CPD, fam, data, ns, cnodes)
%
% fam(i) is the node id of the i-th node in the family of nodes; the self node is the last one
% data(i,m) is the value of node i in case m (can be a cell array).
% ns(i) is the node size of the i-th node in the whole bnet
% cnodes(i) is the node id of the i-th continuous node in the whole bnet
%
% The following optional arguments can be specified in the form of name/value pairs:
% stop_cases: for early stopping (pruning). A node is not split if it has fewer than k cases. Default is 0.
% min_gain: for early stopping (pruning).
%   For discrete output: a node is not split when the gain of the best split is less than min_gain. Default is 0.
%   For continuous (cts) output: a node is not split when the gain of the best split is less than min_gain*score(root)
%   (we denote this cts_min_gain). Default is 0.006.
% %%%%%%%%%%%%%%%%%%% Structure definition of dtree_CPD.tree %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% tree.num_node    the last position in the tree.nodes array at which new nodes are added;
%                  it is not always equal to the number of nodes in the tree, because some positions in the
%                  tree.nodes array can be marked as unused (e.g. after tree pruning)
% tree.nodes       the array of nodes in the tree plus some unused nodes.
%                  tree.nodes(1) is the root of the tree.
%
% Below are the attributes of each node:
% tree.nodes(i).used;             % flags whether this node is used (0 means unused; it can be removed from the tree to save memory)
% tree.nodes(i).is_leaf;          % 1 means this node is a leaf, 0 means it is not a leaf.
% tree.nodes(i).children;         % children(i) is the index in the tree.nodes array of the i-th child node
% tree.nodes(i).split_id;         % the attribute id used to split this node
% tree.nodes(i).split_threshhold; % the threshold used to split this node on a continuous attribute
% %%%%% attributes used only by classification trees (discrete output)
% tree.nodes(i).probs             % probs(i) is the probability of the i-th value of the class node
%                                 % For a three-class output, probs = [0.9 0.1 0.0] means the probability of
%                                 % class 1 is 0.9, of class 2 is 0.1, and of class 3 is 0.0.
% %%%%% attributes used only by regression trees (continuous output)
% tree.nodes(i).mean              % mean output value for this node
% tree.nodes(i).std               % standard deviation of the output values in this node
%
% Author: yimin.zhang@intel.com
% Last updated: Jan. 19, 2002
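%
% Illustrative usage (a minimal sketch, not part of the original BNT code): fam lists the
% parents of the node followed by the node itself, and data is the nnodes-by-ncases evidence
% matrix, as documented above:
%
%   CPD = learn_params(CPD, fam, data, ns, cnodes, 'stop_cases', 5, 'min_gain', 0.01);
%
% The learned tree can then be read off CPD.tree by walking the node fields documented above.
% The sketch below is only an illustration (case_ev and fam_types are hypothetical: case_ev(j)
% is the value of the j-th family member in a new case, fam_types(j)=1 iff that member is continuous):
%
%   t = CPD.tree;
%   n = t.root;
%   while ~t.nodes(n).is_leaf
%     a = t.nodes(n).split_id;                          % family-relative attribute tested at this node
%     if fam_types(a)==1                                % cts attribute: child 1 is <=threshold, child 2 is >threshold
%       c = 1 + (case_ev(a) > t.nodes(n).split_threshhold);
%     else                                              % discrete attribute: one child per value 1..ns(fam(a))
%       c = case_ev(a);
%     end
%     n = t.nodes(n).children(c);
%   end
%   % at the leaf: t.nodes(n).probs (classification) or t.nodes(n).mean, t.nodes(n).std (regression)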

% Want list:
% (1) be more efficient for cts attributes: get the values of each cts attribute once (at the beginning
%     of the build_dtree function), then use bi_search when finding the threshold
% (2) prune classification trees using Pessimistic Error Pruning
% (3) bi_search for strings (used to transform data into BNT format)

global tree % tree must be global so that it can be accessed in the recursive splitting function
global cts_min_gain
tree=[]; % clear the tree
tree.num_node=0;
cts_min_gain=0;

stop_cases=0;
min_gain=0;

args = varargin;
nargs = length(args);
if (nargs>0)
  if isstr(args{1})
    for i=1:2:nargs
      switch args{i},
       case 'stop_cases', stop_cases = args{i+1};
       case 'min_gain', min_gain = args{i+1};
      end
    end
  else
    error('error in input parameters');
  end
end

if iscell(data)
  local_data = cell2num(data(fam,:));
else
  local_data = data(fam, :);
end
%counts = compute_counts(local_data, CPD.sizes);
%CPD.CPT = mk_stochastic(counts + CPD.prior); % bug fix 11/5/01
node_types = zeros(1,size(ns,2)); %all nodes are discrete by default
node_types(cnodes)=1;
%make the data BNT compliant (values for discrete nodes are from 1 to n, where n is the node size)
%trans_data=transform_data(local_data,'tmp.dat',[]); %here no cts nodes

build_dtree(CPD, local_data, ns(fam), node_types(fam), stop_cases, min_gain);
%CPD.tree=copy_tree(tree);
CPD.tree=tree; %copy the constructed tree into the CPD

function new_tree = copy_tree(tree)
% copy the tree to new_tree
new_tree.num_node=tree.num_node;
new_tree.root = tree.root;
for i=1:tree.num_node
  new_tree.nodes(i)=tree.nodes(i);
end

function build_dtree(CPD, fam_ev, node_sizes, node_types, stop_cases, min_gain)
global tree
global cts_min_gain

tree.num_node=0; %the current number of nodes in the tree
tree.root=1;

T = 1:size(fam_ev,2); %all cases
candidate_attrs = 1:(size(node_sizes,2)-1); %all attributes
node_id=1; %the root node
lastnode=size(node_sizes,2); %the last element of the family is the dependent variable (category node)
num_cat=node_sizes(lastnode);

% get the minimum gain for cts output (used to stop splitting)
if (node_types(size(fam_ev,1))==1) %cts output
  N = size(fam_ev,2);
  output_id = size(fam_ev,1);
  cases_T = fam_ev(output_id,:); %get all the output values for cases T
  std_T = std(cases_T);
  avg_y_T = mean(cases_T);
  sqr_T = cases_T - avg_y_T;
  cts_min_gain = min_gain*(sum(sqr_T.*sqr_T)/N); % min_gain * (R(root) = 1/N * SUM(y-avg_y)^2)
end

split_dtree(CPD, fam_ev, node_sizes, node_types, stop_cases, min_gain, T, candidate_attrs, num_cat);

% pruning methods
% (1) Restrictions on minimum node size: a node is not split if it has fewer than k cases.
% (2) Thresholds on impurity: a threshold is imposed on the splitting test score. The threshold can be
%     imposed on a local goodness measure (the gain_ratio of a node) or on a global goodness measure.
% (3) Minimum Error Pruning (MEP) (no pruning set needed)
%     Prune if static error <= backed-up error
%     Static error at node v: e(v) = (Nc + 1)/(N+k) (Laplace estimate, equal prior for each class)
%     here N is the number of all examples, Nc is the number of majority-class examples, k is the number of classes
%     Backed-up error at node v (Ti is the i-th subtree root):
%     E(T) = Sum_1_to_n(pi*e(Ti))
% (4) Pessimistic Error Pruning (PEP), used in Quinlan's C4.5 (no pruning set needed; efficient because it prunes top-down)
%     Probability of error (apparent error rate)
%     q = (N-Nc+0.5)/N
%     where N = #examples, Nc = #examples in the majority class
%     Error of a node v (if pruned): q(v) = (Nv - Nc,v + 0.5)/Nv
%     Error of a subtree:            q(T) = Sum_of_l_leaves(Nl - Nc,l + 0.5)/Sum_of_l_leaves(Nl)
%     Prune if q(v) <= q(T)
%
% Implementation status:
% (1)(2) have been implemented as the input parameters of learn_params.
% (4) is implemented in this function
function pruning(fam_ev,node_sizes,node_types)
% PRUNING prune the constructed tree using PEP
% pruning(fam_ev,node_sizes,node_types)
%
% fam_ev(i,j) is the value of attribute i in the j-th training case (for the whole tree); the last row is the class label (self_ev)
% node_sizes(i) is the node size of the i-th node in the family
% node_types(i) is the node type of the i-th node in the family, 0 for a discrete node, 1 for a continuous node
% the global variable 'tree' stores the input tree and the pruned tree
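%
% The body of this function was left empty; below is a minimal commented sketch of the PEP
% test for one internal node v, following the formulas above (Nv, Nc_v, leaf_N and leaf_Nc are
% hypothetical counts collected from the training cases reaching v and the leaves of its subtree):
%
%   q_v = (Nv - Nc_v + 0.5) / Nv;                     % error of v if it were pruned to a leaf
%   q_T = sum(leaf_N - leaf_Nc + 0.5) / sum(leaf_N);  % error of the subtree rooted at v
%   if (q_v <= q_T)
%     tree.nodes(v).is_leaf = 1;                      % prune: turn v into a leaf
%     % ... and mark all descendants of v as unused (tree.nodes(d).used = 0)
%   end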

function split_T = split_cases(fam_ev,node_sizes,node_types,T,node_i, threshhold)
% SPLIT_CASES split the cases T according to the values of node_i in the family
% split_T = split_cases(fam_ev,node_sizes,node_types,T,node_i)
%
% fam_ev(i,j) is the value of attribute i in the j-th training case (for the whole tree); the last row is the class label (self_ev)
% node_sizes(i) is the node size of the i-th node in the family
% node_types(i) is the node type of the i-th node in the family, 0 for a discrete node, 1 for a continuous node
% node_i is the attribute on which to split
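%
% Illustrative example (values chosen arbitrarily, not from any dataset): if T = [5 6 7 8],
% node_sizes(node_i) = 3 and fam_ev(node_i,T) = [1 3 1 2], then for this discrete attribute
% split_T = { [5 7], [8], [6] }, i.e. the case ids grouped by the value of node_i.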
if (node_types(node_i)==0) %discrete attribute
  %init the subsets of T
  split_T = cell(1,node_sizes(node_i)); %T will be separated into |node_size of i| subsets according to different values of node i
  for i=1:node_sizes(node_i) % here we assume that the value of an attribute is 1:node_size
    split_T{i}=zeros(1,0);
  end

  size_t = size(T,2);
  for i=1:size_t
    case_id = T(i);
    %put this case into one subset of split_T according to its value for node_i
    value = fam_ev(node_i,case_id);
    pos = size(split_T{value},2)+1;
    split_T{value}(pos)=case_id; % here assumes the value of an attribute is 1:node_size
  end
else %continuous attribute
  %init the subsets of T
  split_T = cell(1,2); %T will be separated into 2 subsets (<=threshhold) (>threshhold)
  for i=1:2
    split_T{i}=zeros(1,0);
  end

  size_t = size(T,2);
  for i=1:size_t
    case_id = T(i);
    %put this case into one subset of split_T according to its value for node_i
    value = fam_ev(node_i,case_id);
    subset_num=1;
    if (value>threshhold)
      subset_num=2;
    end
    pos = size(split_T{subset_num},2)+1;
    split_T{subset_num}(pos)=case_id;
  end
end

function new_node = split_dtree (CPD, fam_ev, node_sizes, node_types, stop_cases, min_gain, T, candidate_attrs, num_cat)
% SPLIT_DTREE Recursively split the current tree node using the cases T (T is just a vector of indices into the family evidence).
% new_node = split_dtree (CPD, fam_ev, node_sizes, node_types, stop_cases, min_gain, T, candidate_attrs, num_cat)
%
% fam_ev(i,j) is the value of attribute i in the j-th training case (for the whole tree); the last row is the class label (self_ev)
% node_sizes(i) is the node size of the i-th node in the family
% node_types(i) is the node type of the i-th node in the family, 0 for a discrete node, 1 for a continuous node
% stop_cases is the minimum number of cases below which splitting stops
% min_gain is the minimum gain needed to split a node
% T(i) is the index of the i-th case in the current decision tree node; these cases are to be split further
% candidate_attrs(i) is the id of the i-th attribute that still needs to be considered as a split attribute
% num_cat is the number of output categories for the decision tree
% output:
% new_node is the index of the newly created node
global tree
global cts_min_gain

size_fam = size(fam_ev,1); %family size (number of nodes in the family)
output_type = node_types(size_fam); %the type of the output for the tree (0 is discrete, 1 is continuous)
size_attrs = size(candidate_attrs,2); %number of candidate attributes
size_t = size(T,2); %number of training cases in this tree node

%(1) computeFrequencyForEachClass(T)
if (output_type==0) %discrete output
  class_freqs = zeros(1,num_cat);
  for i=1:size_t
    case_id = T(i);
    case_class = fam_ev(size_fam,case_id); %get the class label for this case
    class_freqs(case_class)=class_freqs(case_class)+1;
  end
else %cts output
  N = size(fam_ev,2);
  cases_T = fam_ev(size(fam_ev,1),T); %get the output values for cases T
  std_T = std(cases_T);
end

%(2) if OneClass (for discrete output) or same output value (for cts output) or #examples < stop_cases
%      return a leaf;
%    create a decision node N;

% get the majority class in this node
if (output_type == 0)
  top1_class = 0; %the class with the largest number of cases
  top1_class_cases = 0; %the number of cases in top1_class
  [top1_class_cases,top1_class]=max(class_freqs);
end

if (size_t==0) %impossible
  new_node=-1;
  fprintf('Fatal error: please contact the author. \n');
  return;
end

% stop splitting if needed
%for discrete output: only one class present
%for cts output: all output values of the cases are the same
%too few cases
% (short-circuit && / || so that std_T is only referenced for cts output, where it is defined)
if ((output_type==0 && top1_class_cases == size_t) || (output_type==1 && std_T == 0) || (size_t < stop_cases))
  %create one new leaf node
  tree.num_node=tree.num_node+1;
  tree.nodes(tree.num_node).used=1; %flag this node as used (0 means unused; it will be removed from the tree at the end to save memory)
  tree.nodes(tree.num_node).is_leaf=1;
  tree.nodes(tree.num_node).children=[];
  tree.nodes(tree.num_node).split_id=0; %the attribute (parent) id used to split this tree node
  tree.nodes(tree.num_node).split_threshhold=0;
  if (output_type==0)
    tree.nodes(tree.num_node).probs=class_freqs/size_t; %the prob of each value of the class node

    % tree.nodes(tree.num_node).probs=zeros(1,num_cat); %the prob of each value of the class node
    % tree.nodes(tree.num_node).probs(top1_class)=1; %use the majority class of the parent node; e.g. for a binary class
    %   where the majority is class 2, the CPT is [0 1]
    %   we may need to use a prior to do smoothing, to get e.g. [0.001 0.999]
    tree.nodes(tree.num_node).error.self_error=1-top1_class_cases/size_t; %the classification error in this tree node when the default class is used
    tree.nodes(tree.num_node).error.all_error=1-top1_class_cases/size_t; %the total classification error in this tree node and its subtree
    tree.nodes(tree.num_node).error.all_error_num=size_t - top1_class_cases;
    fprintf('Create leaf node(onecla) %d. Class %d Cases %d Error %d \n',tree.num_node, top1_class, size_t, size_t - top1_class_cases );
  else
    avg_y_T = mean(cases_T);
    tree.nodes(tree.num_node).mean = avg_y_T;
    tree.nodes(tree.num_node).std = std_T;
    fprintf('Create leaf node(samevalue) %d. Mean %8.4f Std %8.4f Cases %d \n',tree.num_node, avg_y_T, std_T, size_t);
  end
  new_node = tree.num_node;
  return;
end

%create one new node
tree.num_node=tree.num_node+1;
tree.nodes(tree.num_node).used=1; %flag this node as used (0 means unused; it will be removed from the tree at the end to save memory)
tree.nodes(tree.num_node).is_leaf=1;
tree.nodes(tree.num_node).children=[];
tree.nodes(tree.num_node).split_id=0;
tree.nodes(tree.num_node).split_threshhold=0;
if (output_type==0)
  tree.nodes(tree.num_node).error.self_error=1-top1_class_cases/size_t;
  tree.nodes(tree.num_node).error.all_error=0;
  tree.nodes(tree.num_node).error.all_error_num=0;
else
  avg_y_T = mean(cases_T);
  tree.nodes(tree.num_node).mean = avg_y_T;
  tree.nodes(tree.num_node).std = std_T;
end
new_node = tree.num_node;

%Stop splitting if no attributes are left in this node
if (size_attrs==0)
  if (output_type==0)
    tree.nodes(tree.num_node).probs=class_freqs/size_t; %the prob of each value of the class node
    tree.nodes(tree.num_node).error.all_error=1-top1_class_cases/size_t;
    tree.nodes(tree.num_node).error.all_error_num=size_t - top1_class_cases;
    fprintf('Create leaf node(noattr) %d. Class %d Cases %d Error %d \n',tree.num_node, top1_class, size_t, size_t - top1_class_cases );
  else
    fprintf('Create leaf node(noattr) %d. Mean %8.4f Std %8.4f Cases %d \n',tree.num_node, avg_y_T, std_T, size_t);
  end
  return;
end

%(3) for each attribute A
%      ComputeGain(A);
max_gain=0; %the max gain score (information gain or gain ratio for discrete output; the reduction of R(T) for cts output)
best_attr=0; %the attribute with the max_gain
best_split = []; %the split of T according to the value of best_attr
cur_best_threshhold = 0; %the threshold for splitting a continuous attribute
best_threshhold=0;

% compute Info(T) (for discrete output)
if (output_type == 0)
  class_split_T = split_cases(fam_ev,node_sizes,node_types,T,size(fam_ev,1),0); %split cases according to class
  info_T = compute_info (fam_ev, T, class_split_T);
else % compute R(T) (for cts output)
  % N = size(fam_ev,2);
  % cases_T = fam_ev(size(fam_ev,1),T); %get the output values for cases T
  % std_T = std(cases_T);
  % avg_y_T = mean(cases_T);
  sqr_T = cases_T - avg_y_T;
  R_T = sum(sqr_T.*sqr_T)/N; % get R(T) = 1/N * SUM(y-avg_y)^2
  info_T = R_T;
end

for i=1:(size_fam-1)
  if (myismember(i,candidate_attrs)) %if this attribute is still in the candidate attribute set
    if (node_types(i)==0) %discrete attribute
      split_T = split_cases(fam_ev,node_sizes,node_types,T,i,0); %split cases according to the value of attribute i
      % For cts output, we compute the least-squares gain.
      % For discrete output, we compute the gain ratio.
      cur_gain = compute_gain(fam_ev,node_sizes,node_types,T,info_T,i,split_T,0,output_type); %gain ratio
    else %cts attribute
      %get the values of this attribute
      ev = fam_ev(:,T);
      values = ev(i,:);
      sort_v = sort(values);
      %remove the duplicate values in sort_v
      v_set = unique(sort_v);
      best_gain = 0;
      best_threshhold = 0;
      best_split1 = [];

      %find the best split for this cts attribute
      % see "Quinlan 96: Improved Use of Continuous Attributes in C4.5"
      for j=1:(size(v_set,2)-1)
        mid_v = (v_set(j)+v_set(j+1))/2;
        split_T = split_cases(fam_ev,node_sizes,node_types,T,i,mid_v); %split cases according to the value of attribute i (<=mid_v)
        % For cts output, we compute the least-squares gain.
        % For discrete output, we follow Quinlan 96: use information gain instead of gain ratio to select the threshold.
        cur_gain = compute_gain(fam_ev,node_sizes,node_types,T,info_T,i,split_T,1,output_type);
        %if (i==6)
        %  fprintf('gain %8.5f threshhold %6.3f splitting %d\n', cur_gain, mid_v, size(split_T{1},2));
        %end

        if (best_gain < cur_gain)
          best_gain = cur_gain;
          best_threshhold = mid_v;
          %best_split1 = split_T; %here we would need to copy the array, which is not good (we can recompute it after we get best_attr)
        end
      end
      %recalculate the gain_ratio at the best_threshhold
      split_T = split_cases(fam_ev,node_sizes,node_types,T,i,best_threshhold);
      best_gain = compute_gain(fam_ev,node_sizes,node_types,T,info_T,i,split_T,0,output_type); %gain_ratio
      if (output_type==0) %for discrete output
        cur_gain = best_gain-log2(size(v_set,2)-1)/size_t; % Quinlan 96: use gain_ratio-log2(N-1)/|D| as the gain of this attribute
      else %for cts output
        cur_gain = best_gain;
      end
    end

    if (max_gain < cur_gain)
      max_gain = cur_gain;
      best_attr = i;
      cur_best_threshhold=best_threshhold; %save the threshold
      %best_split = split_T; %here we would need to copy the array, which is not good, so we recompute it below once best_attr is known
    end
  end
end

% stop splitting if the gain is too small
if (max_gain==0 | (output_type==0 & max_gain < min_gain) | (output_type==1 & max_gain < cts_min_gain))
  if (output_type==0)
    tree.nodes(tree.num_node).probs=class_freqs/size_t; %the prob of each value of the class node
    tree.nodes(tree.num_node).error.all_error=1-top1_class_cases/size_t;
    tree.nodes(tree.num_node).error.all_error_num=size_t - top1_class_cases;
    fprintf('Create leaf node(nogain) %d. Class %d Cases %d Error %d \n',tree.num_node, top1_class, size_t, size_t - top1_class_cases );
  else
    fprintf('Create leaf node(nogain) %d. Mean %8.4f Std %8.4f Cases %d \n',tree.num_node, avg_y_T, std_T, size_t);
  end
  return;
end

%get the split of cases according to the best split attribute
if (node_types(best_attr)==0) %discrete attribute
  best_split = split_cases(fam_ev,node_sizes,node_types,T,best_attr,0);
else
  best_split = split_cases(fam_ev,node_sizes,node_types,T,best_attr,cur_best_threshhold);
end

%(4) best_attr = AttributeWithBestGain;
%(5) if best_attr is continuous ???? why is this needed? maybe the threshold used in the decision tree must appear in the data
%      find the largest threshold over all cases that is <= max_V
%      change the split of T
tree.nodes(tree.num_node).split_id=best_attr;
tree.nodes(tree.num_node).split_threshhold=cur_best_threshhold; %for cts attributes only

%note: the threshold readjustment below is a linear search, so it is slow. A better method is described in the paper "Efficient C4.5".
%if (output_type==0)
if (node_types(best_attr)==1) %is a continuous attribute
  %find the value that approximates best_threshhold from below (the largest value that is <= best_threshhold)
  best_value=0;
  for i=1:size(fam_ev,2) %note: we need to search over all cases of the whole tree, not just the cases in this node
    val = fam_ev(best_attr,i);
    if (val <= cur_best_threshhold & val > best_value) %val is closer to best_threshhold
      best_value=val;
    end
  end
  tree.nodes(tree.num_node).split_threshhold=best_value; %for cts attributes only
end
%end

if (output_type == 0)
  fprintf('Create node %d split at %d gain %8.4f Th %d. Class %d Cases %d Error %d \n',tree.num_node, best_attr, max_gain, tree.nodes(tree.num_node).split_threshhold, top1_class, size_t, size_t - top1_class_cases );
else
  fprintf('Create node %d split at %d gain %8.4f Th %d. Mean %8.4f Cases %d\n',tree.num_node, best_attr, max_gain, tree.nodes(tree.num_node).split_threshhold, avg_y_T, size_t );
end

%(6) Foreach T' in the split_T
%      if T' is Empty
%        Child of node_id is a leaf
%      else
%        Child of node_id = split_tree (T')
tree.nodes(new_node).is_leaf=0; %because this node will be split, it is no longer a leaf
for i=1:size(best_split,2)
  if (size(best_split{i},2)==0) %T(i) is empty
    %create one new leaf node
    tree.num_node=tree.num_node+1;
    tree.nodes(tree.num_node).used=1; %flag this node as used (0 means unused; it will be removed from the tree at the end to save memory)
    tree.nodes(tree.num_node).is_leaf=1;
    tree.nodes(tree.num_node).children=[];
    tree.nodes(tree.num_node).split_id=0;
    tree.nodes(tree.num_node).split_threshhold=0;
    if (output_type == 0)
      tree.nodes(tree.num_node).probs=zeros(1,num_cat); %the prob of each value of the class node
      tree.nodes(tree.num_node).probs(top1_class)=1; %use the majority class of the parent node; e.g. for a binary class
      %  where the majority is class 2, the CPT is [0 1]
      %  we may need to use a prior to do smoothing, to get e.g. [0.001 0.999]
      tree.nodes(tree.num_node).error.self_error=0;
      tree.nodes(tree.num_node).error.all_error=0;
      tree.nodes(tree.num_node).error.all_error_num=0;
    else
      tree.nodes(tree.num_node).mean = avg_y_T; %just use the parent node's mean value
      tree.nodes(tree.num_node).std = std_T;
    end
    %add the new leaf node to its parent
    num_children=size(tree.nodes(new_node).children,2);
    tree.nodes(new_node).children(num_children+1)=tree.num_node;
    if (output_type==0)
      fprintf('Create leaf node(nullset) %d. %d-th child of Father %d Class %d\n',tree.num_node, i, new_node, top1_class );
    else
      fprintf('Create leaf node(nullset) %d. %d-th child of Father %d \n',tree.num_node, i, new_node );
    end

  else
    if (node_types(best_attr)==0) % if the attribute is discrete, it should be removed from the candidate set
      new_candidate_attrs = mysetdiff(candidate_attrs,[best_attr]);
    else
      new_candidate_attrs = candidate_attrs;
    end
    new_sub_node = split_dtree (CPD, fam_ev, node_sizes, node_types, stop_cases, min_gain, best_split{i}, new_candidate_attrs, num_cat);
    %tree.nodes(parent_id).error.all_error += tree.nodes(new_sub_node).error.all_error;
    fprintf('Add subtree node %d to %d. #nodes %d\n',new_sub_node,new_node, tree.num_node );

    % tree.nodes(new_node).error.all_error_num = tree.nodes(new_node).error.all_error_num + tree.nodes(new_sub_node).error.all_error_num;
    %add the new subtree node to its parent
    num_children=size(tree.nodes(new_node).children,2);
    tree.nodes(new_node).children(num_children+1)=new_sub_node;
  end
end

%(7) Compute the errors of N (needed for pruning)
%    get the total error for the subtree
if (output_type==0)
  tree.nodes(new_node).error.all_error=tree.nodes(new_node).error.all_error_num/size_t;
end
%pruning could be done here, but it would not be efficient, because it is bottom-up.
%if tree.nodes()
%after pruning, the all_error would need to be updated to self_error

%(8) Return N


%(1) For discrete output, we use the GainRatio defined as below:
%
%                       Gain(X,T)
%    GainRatio(X,T) = --------------
%                     SplitInfo(X,T)
%    where
%    Gain(X,T) = Info(T) - Info(X,T)
%
%                                             |Ti|
%    Info(X,T) = Sum for i from 1 to n of ( ------ * Info(Ti))
%                                             |T|
%
%    SplitInfo(D,T) is the information due to the split of T on the basis
%    of the value of the categorical attribute D. Thus SplitInfo(D,T) is
%    I(|T1|/|T|, |T2|/|T|, .., |Tm|/|T|)
%    where {T1, T2, .., Tm} is the partition of T induced by the values of D.
%
%    Definition of Info(Ti):
%    If a set T of records is partitioned into disjoint exhaustive classes C1, C2, .., Ck on the basis of the
%    value of the categorical attribute, then the information needed to identify the class of an element of T
%    is Info(T) = I(P), where P is the probability distribution of the partition (C1, C2, .., Ck):
%    P = (|C1|/|T|, |C2|/|T|, ..., |Ck|/|T|)
%    Here I(P) is defined as
%    I(P) = -(p1*log(p1) + p2*log(p2) + .. + pn*log(pn))
%
%(2) For continuous output (regression trees), we use the least-squares score (adapted from Leo Breiman's book
%    "Classification and regression trees", page 231). The original score supports only binary splits; we extend it
%    to permit multi-way splits.
%
%    Delta_R = R(T) - Sum over all child nodes Ti of R(Ti)
%    where R(Ti) = 1/N * Sum over all cases i in node Ti of (yi - avg_y(Ti))^2
%    here N is the number of all training cases used to construct the regression tree
%    and avg_y(Ti) is the average output value for the cases in node Ti
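%
%    Worked example (illustrative numbers only): suppose T has 8 cases, 4 of class 1 and 4 of class 2,
%    so Info(T) = -(0.5*log2(0.5) + 0.5*log2(0.5)) = 1. A discrete attribute X splits T into
%    T1 (3 of class 1, 1 of class 2) and T2 (1 of class 1, 3 of class 2). Then
%      Info(T1) = Info(T2) = -(0.75*log2(0.75) + 0.25*log2(0.25)) ~= 0.8113
%      Info(X,T) = 4/8*0.8113 + 4/8*0.8113 = 0.8113
%      Gain(X,T) = 1 - 0.8113 = 0.1887
%      SplitInfo(X,T) = I(4/8, 4/8) = 1, so GainRatio(X,T) = 0.1887.
%    For the regression score, if the 4 training outputs are y = [1 1 3 3] (N = 4), then R(T) = 1 and a
%    split into {1 1} and {3 3} gives R(T1) = R(T2) = 0, so Delta_R = 1.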

function gain_score = compute_gain (fam_ev, node_sizes, node_types, T, info_T, attr_id, split_T, score_type, output_type)
% COMPUTE_GAIN Compute the score for the split of cases T using attribute attr_id
% gain_score = compute_gain (fam_ev, node_sizes, node_types, T, info_T, attr_id, split_T, score_type, output_type)
%
% fam_ev(i,j) is the value of attribute i in the j-th training case; the last row is the class label (self_ev)
% node_sizes(i) is the node size of the i-th node in the family
% node_types(i) is the node type of the i-th node in the family, 0 for a discrete node, 1 for a continuous node
% T(i) is the index of the i-th case in the current decision tree node; these cases are to be split further
% info_T is Info(T) (discrete output) or R(T) (cts output) for the cases T
% attr_id is the index of the attribute considered for the split
% split_T{i} is the i-th subset in the partition of the cases T according to the value of attribute attr_id
% score_type: 0 means gain ratio, 1 means information gain (only applies to discrete output)
% output_type: 0 means discrete output, 1 means continuous output.
gain_score=0;
% *********** for DISCRETE output *******************************************************
if (output_type == 0)
  % compute Info(T)
  total_cnt = size(T,2);
  if (total_cnt==0)
    return;
  end
  %class_split_T = split_cases(fam_ev,node_sizes,node_types,T,size(fam_ev,1),0); %split cases according to class
  %info_T = compute_info (fam_ev, T, class_split_T);

  % compute Info(X,T)
  num_class = size(split_T,2);
  subset_sizes = zeros(1,num_class);
  info_ti = zeros(1,num_class);
  for i=1:num_class
    subset_sizes(i)=size(split_T{i},2);
    if (subset_sizes(i)~=0)
      class_split_Ti = split_cases(fam_ev,node_sizes,node_types,split_T{i},size(fam_ev,1),0); %split cases according to class
      info_ti(i) = compute_info(fam_ev, split_T{i}, class_split_Ti);
    end
  end
  ti_ratios = subset_sizes/total_cnt; %get the |Ti|/|T|
  info_X_T = sum(ti_ratios.*info_ti);

  %get Gain(X,T)
  gain_X_T = info_T - info_X_T;

  if (score_type == 1) %information gain
    gain_score=gain_X_T;
    return;
  end
  %compute SplitInfo(X,T) (note: for a cts attribute, T is only split into two subsets)
  splitinfo_T = compute_info (fam_ev, T, split_T);
  if (splitinfo_T~=0)
    gain_score = gain_X_T/splitinfo_T;
  end

% ************ for CONTINUOUS output ****************************************************
else
  N = size(fam_ev,2);

  % compute R(Ti)
  num_class = size(split_T,2);
  R_Ti = zeros(1,num_class);
  for i=1:num_class
    if (size(split_T{i},2)~=0)
      cases_T = fam_ev(size(fam_ev,1),split_T{i});
      avg_y_T = mean(cases_T);
      sqr_T = cases_T - avg_y_T;
      R_Ti(i) = sum(sqr_T.*sqr_T)/N; % get R(Ti) = 1/N * SUM(y-avg_y)^2
    end
  end
  %delta_R = R(T) - SUM(R(Ti))
  gain_score = info_T - sum(R_Ti);

end


% Definition of Info(Ti)
% If a set T of records is partitioned into disjoint exhaustive classes C1, C2, .., Ck on the basis of the
% value of the categorical attribute, then the information needed to identify the class of an element of T
% is Info(T) = I(P), where P is the probability distribution of the partition (C1, C2, .., Ck):
% P = (|C1|/|T|, |C2|/|T|, ..., |Ck|/|T|)
% Here I(P) is defined as
% I(P) = -(p1*log(p1) + p2*log(p2) + .. + pn*log(pn))
function info = compute_info (fam_ev, T, split_T)
% COMPUTE_INFO compute the information for the split of T into split_T
% info = compute_info (fam_ev, T, split_T)
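%
% Illustrative example (arbitrary numbers): if T has 8 cases and split_T contains subsets of
% sizes 2 and 6, then info = -(0.25*log2(0.25) + 0.75*log2(0.75)) ~= 0.8113.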

total_cnt = size(T,2);
num_class = size(split_T,2);
subset_sizes = zeros(1,num_class);
probs = zeros(1,num_class);
log_probs = zeros(1,num_class);
for i=1:num_class
  subset_sizes(i)=size(split_T{i},2);
end

probs = subset_sizes/total_cnt;
%log_probs = log2(probs); % if probs(i)=0, then log2(probs(i)) would be -Inf
for i=1:size(probs,2)
  if (probs(i)~=0)
    log_probs(i)=log2(probs(i));
  end
end

info = sum(-(probs.*log_probs));