Mercurial > hg > camir-aes2014
diff toolboxes/FullBNT-1.0.7/bnt/CPDs/@tree_CPD/evaluate_tree_performance.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/toolboxes/FullBNT-1.0.7/bnt/CPDs/@tree_CPD/evaluate_tree_performance.m Tue Feb 10 15:05:51 2015 +0000 @@ -0,0 +1,82 @@ +function [score,outputs] = evaluate(CPD, fam, data, ns, cnodes) +% Evaluate evaluate the performance of the classification/regression tree on given complete data +% score = evaluate(CPD, fam, data, ns, cnodes) +% +% fam(i) is the node id of the i-th node in the family of nodes, self node is the last one +% data(i,m) is the value of node i in case m (can be cell array). +% ns(i) is the node size for the i-th node in the whold bnet +% cnodes(i) is the node id for the i-th continuous node in the whole bnet +% +% Output +% score is the classification accuracy (for classification) +% or mean square deviation (for regression) +% here for every case we use the mean value at the tree leaf node as its predicted value +% outputs(i) is the predicted output value for case i +% +% Author: yimin.zhang@intel.com +% Last updated: Jan. 19, 2002 + + +if iscell(data) + local_data = cell2num(data(fam,:)); +else + local_data = data(fam, :); +end + +%get local node sizes and node types +node_sizes = ns(fam); +node_types = zeros(1,size(ns,2)); %all nodes are disrete +node_types(cnodes)=1; +node_types=node_types(fam); + +fam_size=size(fam,2); +output_type = node_types(fam_size); + +num_cases=size(local_data,2); +total_error=0; + +outputs=zeros(1,num_cases); +for i=1:num_cases + %class one case using the tree + cur_node=CPD.tree.root; % at the root node of the tree + while (1) + if (CPD.tree.nodes(cur_node).is_leaf==1) + if (output_type==0) %output is discrete + %use the class with max probability as the output + [maxvalue,class_id]=max(CPD.tree.nodes(cur_node).probs); + outputs(i)=class_id; + if (class_id~=local_data(fam_size,i)) + total_error=total_error+1; + end + else %output is continuous + %use the mean as the value + outputs(i)=CPD.tree.nodes(cur_node).mean; + cur_deviation = CPD.tree.nodes(cur_node).mean-local_data(fam_size,i); + total_error=total_error+cur_deviation*cur_deviation; + end + break; + end + cur_attr = CPD.tree.nodes(cur_node).split_id; + attr_val = local_data(cur_attr,i); + if (node_types(cur_attr)==0) %discrete attribute + % goto the attr_val -th child + cur_node = CPD.tree.nodes(cur_node).children(attr_val); + else + if (attr_val <= CPD.tree.nodes(cur_node).split_threshhold) + cur_node = CPD.tree.nodes(cur_node).children(1); + else + cur_node = CPD.tree.nodes(cur_node).children(2); + end + end + if (cur_node > CPD.tree.num_node) + fprintf('Fatal error: Tree structure corrupted.\n'); + return; + end + end + %update the classification error number +end +if (output_type==0) + score=1-total_error/num_cases; +else + score=total_error/num_cases; +end