To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.

Statistics Download as Zip
| Branch: | Revision:

root / _FullBNT / BNT / CPDs / @tree_CPD / evaluate_tree_performance.m @ 8:b5b38998ef3b

History | View | Annotate | Download (2.75 KB)

1
function [score,outputs] = evaluate(CPD, fam, data, ns, cnodes)
% EVALUATE Evaluate the performance of the classification/regression tree on given complete data
% [score, outputs] = evaluate(CPD, fam, data, ns, cnodes)
%
% Input
% CPD         is the tree CPD; this function reads CPD.tree.root, CPD.tree.num_node
%             and CPD.tree.nodes(k) with fields is_leaf, probs, mean, split_id,
%             split_threshhold and children
% fam(i)      is the node id of the i-th node in the family of nodes, self node is the last one
% data(i,m)   is the value of node i in case m (can be cell array)
% ns(i)       is the node size for the i-th node in the whole bnet
% cnodes(i)   is the node id for the i-th continuous node in the whole bnet
%
% Output
% score       is the classification accuracy (for classification)
%             or mean square deviation (for regression);
%             here for every case we use the mean value at the tree leaf node as its predicted value
% outputs(i)  is the predicted output value for case i
%
% An error is raised if a child index met while walking the tree lies
% outside 1..CPD.tree.num_node (corrupted tree structure).
%
% NOTE(review): the hosting file appears to be evaluate_tree_performance.m while the
% declaration says "evaluate"; MATLAB resolves the main function by file name, so the
% declared name is left untouched here - confirm before renaming.
%
% Author: yimin.zhang@intel.com
% Last updated: Jan. 19, 2002

% Restrict the data to the rows of this family (converted to numeric if needed).
if iscell(data)
  local_data = cell2num(data(fam,:));
else
  local_data = data(fam, :);
end

% Node types for the whole net: 0 = discrete, 1 = continuous,
% then narrowed down to the family members.
% length() is used instead of size(.,2) so row and column vectors both work.
node_types = zeros(1, length(ns));   % all nodes are discrete by default
node_types(cnodes) = 1;
node_types = node_types(fam);

fam_size    = length(fam);           % the self (output) node is the last family member
output_type = node_types(fam_size);

num_cases   = size(local_data, 2);
total_error = 0;                     % misclassification count, or sum of squared deviations

% Hoist loop-invariant lookups out of the per-case loop.
tree       = CPD.tree;
tree_nodes = tree.nodes;
num_node   = tree.num_node;

outputs = zeros(1, num_cases);
for i = 1:num_cases
  % Classify one case by walking from the root down to a leaf.
  cur_node = tree.root;
  while (1)
    if (tree_nodes(cur_node).is_leaf == 1)
      actual = local_data(fam_size, i);
      if (output_type == 0)   % output is discrete
        % use the class with max probability as the output
        [maxvalue, class_id] = max(tree_nodes(cur_node).probs);
        outputs(i) = class_id;
        if (class_id ~= actual)
          total_error = total_error + 1;
        end
      else                    % output is continuous
        % use the leaf mean as the predicted value
        outputs(i) = tree_nodes(cur_node).mean;
        cur_deviation = tree_nodes(cur_node).mean - actual;
        total_error = total_error + cur_deviation*cur_deviation;
      end
      break;
    end

    % Internal node: pick the child according to the split attribute.
    cur_attr = tree_nodes(cur_node).split_id;
    attr_val = local_data(cur_attr, i);
    if (node_types(cur_attr) == 0)   % discrete attribute
      % go to the attr_val-th child
      cur_node = tree_nodes(cur_node).children(attr_val);
    else                             % continuous attribute: binary threshold split
      if (attr_val <= tree_nodes(cur_node).split_threshhold)
        cur_node = tree_nodes(cur_node).children(1);
      else
        cur_node = tree_nodes(cur_node).children(2);
      end
    end

    % Fail loudly on a bad child index. Returning here would leave 'score'
    % unassigned and mask the real cause with a generic MATLAB error.
    in_range = (cur_node >= 1) & (cur_node <= num_node);
    if ~in_range
      error('Fatal error: Tree structure corrupted.');
    end
  end
end

if (output_type == 0)
  score = 1 - total_error/num_cases;   % classification accuracy
else
  score = total_error/num_cases;       % mean square deviation
end