To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.

Statistics Download as Zip
| Branch: | Revision:

root / _FullBNT / BNT / general / Old / calc_mpe_bucket.m @ 8:b5b38998ef3b

History | View | Annotate | Download (4.26 KB)

1
function [mpe, ll] = calc_mpe_bucket(bnet, new_evidence, max_over)
2
%
3
% PURPOSE:
4
%       CALC_MPE Computes the most probable explanation to the network nodes
5
%       given the evidence.
6
%       
7
%       [mpe, ll] = calc_mpe(engine, new_evidence, max_over)
8
%
9
% INPUT:
10
%       bnet  - the bayesian network
11
%       new_evidence - optional, if specified - evidence to be incorporated [cell(1,n)]
12
%       max_over - optional, if specified determines the variable elimination order [1:n]
13
%
14
% OUTPUT:
15
%       mpe - the MPE assignmet for the net variables (or [] if no satisfying assignment)
16
%       ll - log assignment probability.
17
%
18
% Notes:
19
% 1. Adapted from '@var_elim_inf_engine\marginal_nodes' for MPE by Ron Zohar, 8/7/01
20
% 2. Only discrete potentials are supported at this time.
21
% 3. Complexity: O(nw*) where n is the number of nodes and w* is the induced tree width.
22
% 4. Implementation based on:
23
%  - R. Dechter, "Bucket Elimination: A Unifying Framework for Probabilistic Inference", 
24
%                 UA1 96, pp. 211-219.
25

    
26

    
27
ns = bnet.node_sizes;
28
n = length(bnet.dag);
29
evidence = cell(1,n);
30
if (nargin<2)
31
    new_evidence = evidence;
32
end
33

    
34
onodes = find(~isemptycell(new_evidence));  % observed nodes
35
hnodes = find(isemptycell(new_evidence));  % hidden nodes
36
pot_type = determine_pot_type(bnet, onodes);
37

    
38
if pot_type ~= 'd'
39
  error('only disrete potentials supported at this time')    
40
end
41

    
42
for i=1:n
43
  fam = family(bnet.dag, i);
44
  CPT{i} = convert_to_pot(bnet.CPD{bnet.equiv_class(i)}, pot_type, fam(:), evidence);        
45
end 
46

    
47
% handle observed nodes: set impossible cases' probability to zero
48
% rather than prun matrix (this makes backtracking easier)
49

    
50
for ii=onodes
51
  lIdx = 1:ns(ii);
52
  lIdx = setdiff(lIdx, new_evidence{ii});
53
  
54
  sCPT=struct(CPT{ii});  % violate object privacy
55
  
56
  sargs = '';
57
  for jj=1:(length(sCPT.domain)-1)
58
    sargs = [sargs, ':,']; 
59
  end        
60
  for jj=lIdx
61
    eval(['sCPT.T(', sargs, num2str(jj), ')=0;']);
62
  end
63
  CPT{ii}=dpot(sCPT.domain, sCPT.sizes, sCPT.T);        
64
end
65

    
66
B = cell(1,n); 
67
for b=1:n
68
  B{b} = mk_initial_pot(pot_type, [], [], [], []);
69
end
70

    
71
if (nargin<3)
72
  max_over = (1:n);
73
end   
74
order = max_over; % no attempt to optimize this
75

    
76

    
77
% Initialize the buckets with the CPDs assigned to them
78
for i=1:n
79
  b = bucket_num(domain_pot(CPT{i}), order);
80
  B{b} = multiply_pots(B{b}, CPT{i});
81
end
82

    
83
% Do backward phase
84
max_over = max_over(length(max_over):-1:1); % reverse
85
for i=max_over(1:end-1)        
86
  % max-ing over variable i which occurs in bucket j
87
  j = bucket_num(i, order);
88
  rest = mysetdiff(domain_pot(B{j}), i);
89
  %temp = marginalize_pot_max(B{j}, rest);
90
  temp = marginalize_pot(B{j}, rest, 1);
91
  b = bucket_num(domain_pot(temp), order);
92
  %        fprintf('maxing over bucket %d (var %d), putting result into bucket %d\n', j, i, b);
93
  sB=struct(B{b});  % violate object privacy
94
  if ~isempty(sB.domain)
95
    B{b} = multiply_pots(B{b}, temp);
96
  else
97
    B{b} = temp;
98
  end
99
end
100
result = B{1};
101
marginal = pot_to_marginal(result);
102
[prob, mpe] = max(marginal.T);
103

    
104
% handle impossible cases
105
if ~(prob>0)
106
  mpe = [];    
107
  ll = -inf;
108
  %warning('evidence has zero probability')
109
  return
110
end
111

    
112
ll = log(prob);
113

    
114
% Do forward phase    
115
for ii=2:n
116
  marginal = pot_to_marginal(B{ii});
117
  mpeidx = [];
118
  for jj=order(1:length(mpe))
119
    assert(ismember(jj, marginal.domain)) %%% bug
120
    temp = find_equiv_posns(jj, marginal.domain);
121
    mpeidx = [mpeidx, temp] ;
122
    if isempty(temp)
123
      mpeidx = [mpeidx, Inf] ;
124
    end
125
  end
126
  [mpeidxsorted sortedtompe] = sort(mpeidx) ;
127
  
128
  % maximize the matrix obtained from assigning values from previous buckets.
129
  % this is done by building a string and using eval.
130
  
131
  kk=1;
132
  sargs = '(';
133
  for jj=1:length(marginal.domain)
134
    if (jj~=1)
135
      sargs = [sargs, ','];
136
    end
137
    if (mpeidxsorted(kk)==jj)
138
      sargs = [sargs, num2str(mpe(sortedtompe(kk)))];
139
      if (kk<length(mpe))
140
	kk = kk+1 ;
141
      end
142
    else
143
      sargs = [sargs, ':'];
144
    end
145
  end
146
  sargs = [sargs, ')'] ;   
147
  eval(['[val, loc] = max(marginal.T', sargs, ');'])        
148
  mpe = [mpe loc];
149
end     
150
[I,J] = sort(order);
151
mpe = mpe(J);
152

    
153

    
154

    
155
%%%%%%%%%
156

    
157
function b = bucket_num(domain, order)
158

    
159
b = max(find_equiv_posns(domain, order));
160