view toolboxes/FullBNT-1.0.7/bnt/CPDs/@tree_CPD/evaluate_tree_performance.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line source
function [score,outputs] = evaluate(CPD, fam, data, ns, cnodes)
% Evaluate evaluate the performance of the classification/regression tree on given complete data
% score = evaluate(CPD, fam, data, ns, cnodes)
%
% fam(i) is the node id of the i-th node in the family of nodes, self node is the last one
% data(i,m) is the value of node i in case m (can be cell array).
% ns(i) is the node size for the i-th node in the whold bnet
% cnodes(i) is the node id for the i-th continuous node in the whole bnet
%  
% Output
% score is the classification accuracy (for classification) 
%          or mean square deviation (for regression)
%            here for every case we use the mean value at the tree leaf node as its predicted value
% outputs(i) is the predicted output value for case i
%
% Author: yimin.zhang@intel.com
% Last updated: Jan. 19, 2002


if iscell(data)
  local_data = cell2num(data(fam,:));
else
  local_data = data(fam, :);
end

%get local node sizes and node types
node_sizes = ns(fam);
node_types = zeros(1,size(ns,2)); %all nodes are disrete
node_types(cnodes)=1;
node_types=node_types(fam);

fam_size=size(fam,2);
output_type = node_types(fam_size);

num_cases=size(local_data,2);
total_error=0;

outputs=zeros(1,num_cases);
for i=1:num_cases
  %class one case using the tree
  cur_node=CPD.tree.root;  % at the root node of the tree
  while (1)
    if (CPD.tree.nodes(cur_node).is_leaf==1)
      if (output_type==0) %output is discrete
        %use the class with max probability as the output  
        [maxvalue,class_id]=max(CPD.tree.nodes(cur_node).probs);
        outputs(i)=class_id;
        if (class_id~=local_data(fam_size,i))
          total_error=total_error+1;
        end
      else   %output is continuous
        %use the mean as the value
        outputs(i)=CPD.tree.nodes(cur_node).mean;
        cur_deviation = CPD.tree.nodes(cur_node).mean-local_data(fam_size,i);
        total_error=total_error+cur_deviation*cur_deviation;
      end
      break;
    end
    cur_attr = CPD.tree.nodes(cur_node).split_id; 
    attr_val = local_data(cur_attr,i);
    if (node_types(cur_attr)==0)  %discrete attribute
        % goto the attr_val -th child
        cur_node = CPD.tree.nodes(cur_node).children(attr_val);
    else
        if (attr_val <= CPD.tree.nodes(cur_node).split_threshhold)
          cur_node = CPD.tree.nodes(cur_node).children(1);
        else
          cur_node = CPD.tree.nodes(cur_node).children(2);  
        end
    end
    if (cur_node > CPD.tree.num_node)
      fprintf('Fatal error: Tree structure corrupted.\n');
      return;
    end
  end
  %update the classification error number
end
if (output_type==0)
  score=1-total_error/num_cases;
else
  score=total_error/num_cases;
end