comparison toolboxes/FullBNT-1.0.7/bnt/CPDs/@tree_CPD/evaluate_tree_performance.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
function [score,outputs] = evaluate(CPD, fam, data, ns, cnodes)
% EVALUATE Evaluate the performance of the classification/regression tree on given complete data
% [score, outputs] = evaluate(CPD, fam, data, ns, cnodes)
%
% Each case is routed from CPD.tree.root down to a leaf of the tree, and the
% leaf's prediction is compared with the observed value of the self node.
%
% Input
% fam(i) is the node id of the i-th node in the family of nodes; the self (output) node is the last one
% data(i,m) is the value of node i in case m (can be cell array)
% ns(i) is the node size for the i-th node in the whole bnet (row or column vector)
% cnodes(i) is the node id for the i-th continuous node in the whole bnet
%   (optional; defaults to [], i.e. every node is treated as discrete)
%
% Output
% score is the classification accuracy (discrete output node)
%   or the mean squared error (continuous output node); NaN when there are no cases
% outputs(m) is the predicted output value for case m:
%   discrete output   -> index of the class with maximum probability at the leaf
%   continuous output -> mean value stored at the leaf
%
% Raises an error (id 'tree_CPD:corruptedTree') if the tree walk reaches a
% node id outside 1..CPD.tree.num_node.
%
% Author: yimin.zhang@intel.com
% Last updated: Jan. 19, 2002

if nargin < 5
  cnodes = [];
end

if iscell(data)
  local_data = cell2num(data(fam,:));
else
  local_data = data(fam, :);
end

% Local node types: 0 = discrete, 1 = continuous.
% numel (rather than size(...,2)) so that ns and fam may be row or column vectors.
node_types = zeros(1, numel(ns));
node_types(cnodes) = 1;
node_types = node_types(fam);

fam_size = numel(fam);
output_type = node_types(fam_size);   % type of the self (output) node

num_cases = size(local_data,2);
total_error = 0;   % misclassification count, or sum of squared deviations

outputs = zeros(1,num_cases);
for i=1:num_cases
  % classify one case by walking the tree from the root to a leaf
  cur_node = CPD.tree.root;
  while (1)
    if (CPD.tree.nodes(cur_node).is_leaf==1)
      if (output_type==0)
        % discrete output: predict the class with max probability at the leaf
        [maxvalue,class_id] = max(CPD.tree.nodes(cur_node).probs);
        outputs(i) = class_id;
        if (class_id ~= local_data(fam_size,i))
          total_error = total_error+1;
        end
      else
        % continuous output: predict the leaf mean
        outputs(i) = CPD.tree.nodes(cur_node).mean;
        cur_deviation = CPD.tree.nodes(cur_node).mean - local_data(fam_size,i);
        total_error = total_error + cur_deviation*cur_deviation;
      end
      break;
    end
    % split_id is used as a row index into local_data, i.e. a position within fam
    cur_attr = CPD.tree.nodes(cur_node).split_id;
    attr_val = local_data(cur_attr,i);
    if (node_types(cur_attr)==0)
      % discrete attribute: go to the attr_val-th child
      cur_node = CPD.tree.nodes(cur_node).children(attr_val);
    else
      % continuous attribute: binary split, <= threshold goes to child 1
      if (attr_val <= CPD.tree.nodes(cur_node).split_threshhold)
        cur_node = CPD.tree.nodes(cur_node).children(1);
      else
        cur_node = CPD.tree.nodes(cur_node).children(2);
      end
    end
    if (cur_node < 1 || cur_node > CPD.tree.num_node)
      % raise instead of returning with unassigned outputs
      error('tree_CPD:corruptedTree', 'Fatal error: Tree structure corrupted.');
    end
  end
end

if (output_type==0)
  score = 1 - total_error/num_cases;   % classification accuracy
else
  score = total_error/num_cases;       % mean squared error
end