To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.
root / _FullBNT / BNT / CPDs / @tree_CPD / evaluate_tree_performance.m @ 8:b5b38998ef3b
History | View | Annotate | Download (2.75 KB)
| 1 |
function [score,outputs] = evaluate(CPD, fam, data, ns, cnodes) |
|---|---|
| 2 |
% Evaluate evaluate the performance of the classification/regression tree on given complete data |
| 3 |
% score = evaluate(CPD, fam, data, ns, cnodes) |
| 4 |
% |
| 5 |
% fam(i) is the node id of the i-th node in the family of nodes, self node is the last one |
| 6 |
% data(i,m) is the value of node i in case m (can be cell array). |
| 7 |
% ns(i) is the node size for the i-th node in the whold bnet |
| 8 |
% cnodes(i) is the node id for the i-th continuous node in the whole bnet |
| 9 |
% |
| 10 |
% Output |
| 11 |
% score is the classification accuracy (for classification) |
| 12 |
% or mean square deviation (for regression) |
| 13 |
% here for every case we use the mean value at the tree leaf node as its predicted value |
| 14 |
% outputs(i) is the predicted output value for case i |
| 15 |
% |
| 16 |
% Author: yimin.zhang@intel.com |
| 17 |
% Last updated: Jan. 19, 2002 |
| 18 |
|
| 19 |
|
| 20 |
if iscell(data) |
| 21 |
local_data = cell2num(data(fam,:)); |
| 22 |
else |
| 23 |
local_data = data(fam, :); |
| 24 |
end |
| 25 |
|
| 26 |
%get local node sizes and node types |
| 27 |
node_sizes = ns(fam); |
| 28 |
node_types = zeros(1,size(ns,2)); %all nodes are disrete |
| 29 |
node_types(cnodes)=1; |
| 30 |
node_types=node_types(fam); |
| 31 |
|
| 32 |
fam_size=size(fam,2); |
| 33 |
output_type = node_types(fam_size); |
| 34 |
|
| 35 |
num_cases=size(local_data,2); |
| 36 |
total_error=0; |
| 37 |
|
| 38 |
outputs=zeros(1,num_cases); |
| 39 |
for i=1:num_cases |
| 40 |
%class one case using the tree |
| 41 |
cur_node=CPD.tree.root; % at the root node of the tree |
| 42 |
while (1) |
| 43 |
if (CPD.tree.nodes(cur_node).is_leaf==1) |
| 44 |
if (output_type==0) %output is discrete |
| 45 |
%use the class with max probability as the output |
| 46 |
[maxvalue,class_id]=max(CPD.tree.nodes(cur_node).probs); |
| 47 |
outputs(i)=class_id; |
| 48 |
if (class_id~=local_data(fam_size,i)) |
| 49 |
total_error=total_error+1; |
| 50 |
end |
| 51 |
else %output is continuous |
| 52 |
%use the mean as the value |
| 53 |
outputs(i)=CPD.tree.nodes(cur_node).mean; |
| 54 |
cur_deviation = CPD.tree.nodes(cur_node).mean-local_data(fam_size,i); |
| 55 |
total_error=total_error+cur_deviation*cur_deviation; |
| 56 |
end |
| 57 |
break; |
| 58 |
end |
| 59 |
cur_attr = CPD.tree.nodes(cur_node).split_id; |
| 60 |
attr_val = local_data(cur_attr,i); |
| 61 |
if (node_types(cur_attr)==0) %discrete attribute |
| 62 |
% goto the attr_val -th child |
| 63 |
cur_node = CPD.tree.nodes(cur_node).children(attr_val); |
| 64 |
else |
| 65 |
if (attr_val <= CPD.tree.nodes(cur_node).split_threshhold) |
| 66 |
cur_node = CPD.tree.nodes(cur_node).children(1); |
| 67 |
else |
| 68 |
cur_node = CPD.tree.nodes(cur_node).children(2); |
| 69 |
end |
| 70 |
end |
| 71 |
if (cur_node > CPD.tree.num_node) |
| 72 |
fprintf('Fatal error: Tree structure corrupted.\n');
|
| 73 |
return; |
| 74 |
end |
| 75 |
end |
| 76 |
%update the classification error number |
| 77 |
end |
| 78 |
if (output_type==0) |
| 79 |
score=1-total_error/num_cases; |
| 80 |
else |
| 81 |
score=total_error/num_cases; |
| 82 |
end |