To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.
root / _FullBNT / BNT / general / Old / calc_mpe_bucket.m @ 8:b5b38998ef3b
History | View | Annotate | Download (4.26 KB)
| 1 |
function [mpe, ll] = calc_mpe_bucket(bnet, new_evidence, max_over) |
|---|---|
| 2 |
% |
| 3 |
% PURPOSE: |
| 4 |
% CALC_MPE Computes the most probable explanation to the network nodes |
| 5 |
% given the evidence. |
| 6 |
% |
| 7 |
% [mpe, ll] = calc_mpe(engine, new_evidence, max_over) |
| 8 |
% |
| 9 |
% INPUT: |
| 10 |
% bnet - the bayesian network |
| 11 |
% new_evidence - optional, if specified - evidence to be incorporated [cell(1,n)] |
| 12 |
% max_over - optional, if specified determines the variable elimination order [1:n] |
| 13 |
% |
| 14 |
% OUTPUT: |
| 15 |
% mpe - the MPE assignmet for the net variables (or [] if no satisfying assignment) |
| 16 |
% ll - log assignment probability. |
| 17 |
% |
| 18 |
% Notes: |
| 19 |
% 1. Adapted from '@var_elim_inf_engine\marginal_nodes' for MPE by Ron Zohar, 8/7/01 |
| 20 |
% 2. Only discrete potentials are supported at this time. |
| 21 |
% 3. Complexity: O(nw*) where n is the number of nodes and w* is the induced tree width. |
| 22 |
% 4. Implementation based on: |
| 23 |
% - R. Dechter, "Bucket Elimination: A Unifying Framework for Probabilistic Inference", |
| 24 |
% UA1 96, pp. 211-219. |
| 25 |
|
| 26 |
|
| 27 |
ns = bnet.node_sizes; |
| 28 |
n = length(bnet.dag); |
| 29 |
evidence = cell(1,n); |
| 30 |
if (nargin<2) |
| 31 |
new_evidence = evidence; |
| 32 |
end |
| 33 |
|
| 34 |
onodes = find(~isemptycell(new_evidence)); % observed nodes |
| 35 |
hnodes = find(isemptycell(new_evidence)); % hidden nodes |
| 36 |
pot_type = determine_pot_type(bnet, onodes); |
| 37 |
|
| 38 |
if pot_type ~= 'd' |
| 39 |
error('only disrete potentials supported at this time')
|
| 40 |
end |
| 41 |
|
| 42 |
for i=1:n |
| 43 |
fam = family(bnet.dag, i); |
| 44 |
CPT{i} = convert_to_pot(bnet.CPD{bnet.equiv_class(i)}, pot_type, fam(:), evidence);
|
| 45 |
end |
| 46 |
|
| 47 |
% handle observed nodes: set impossible cases' probability to zero |
| 48 |
% rather than prun matrix (this makes backtracking easier) |
| 49 |
|
| 50 |
for ii=onodes |
| 51 |
lIdx = 1:ns(ii); |
| 52 |
lIdx = setdiff(lIdx, new_evidence{ii});
|
| 53 |
|
| 54 |
sCPT=struct(CPT{ii}); % violate object privacy
|
| 55 |
|
| 56 |
sargs = ''; |
| 57 |
for jj=1:(length(sCPT.domain)-1) |
| 58 |
sargs = [sargs, ':,']; |
| 59 |
end |
| 60 |
for jj=lIdx |
| 61 |
eval(['sCPT.T(', sargs, num2str(jj), ')=0;']);
|
| 62 |
end |
| 63 |
CPT{ii}=dpot(sCPT.domain, sCPT.sizes, sCPT.T);
|
| 64 |
end |
| 65 |
|
| 66 |
B = cell(1,n); |
| 67 |
for b=1:n |
| 68 |
B{b} = mk_initial_pot(pot_type, [], [], [], []);
|
| 69 |
end |
| 70 |
|
| 71 |
if (nargin<3) |
| 72 |
max_over = (1:n); |
| 73 |
end |
| 74 |
order = max_over; % no attempt to optimize this |
| 75 |
|
| 76 |
|
| 77 |
% Initialize the buckets with the CPDs assigned to them |
| 78 |
for i=1:n |
| 79 |
b = bucket_num(domain_pot(CPT{i}), order);
|
| 80 |
B{b} = multiply_pots(B{b}, CPT{i});
|
| 81 |
end |
| 82 |
|
| 83 |
% Do backward phase |
| 84 |
max_over = max_over(length(max_over):-1:1); % reverse |
| 85 |
for i=max_over(1:end-1) |
| 86 |
% max-ing over variable i which occurs in bucket j |
| 87 |
j = bucket_num(i, order); |
| 88 |
rest = mysetdiff(domain_pot(B{j}), i);
|
| 89 |
%temp = marginalize_pot_max(B{j}, rest);
|
| 90 |
temp = marginalize_pot(B{j}, rest, 1);
|
| 91 |
b = bucket_num(domain_pot(temp), order); |
| 92 |
% fprintf('maxing over bucket %d (var %d), putting result into bucket %d\n', j, i, b);
|
| 93 |
sB=struct(B{b}); % violate object privacy
|
| 94 |
if ~isempty(sB.domain) |
| 95 |
B{b} = multiply_pots(B{b}, temp);
|
| 96 |
else |
| 97 |
B{b} = temp;
|
| 98 |
end |
| 99 |
end |
| 100 |
result = B{1};
|
| 101 |
marginal = pot_to_marginal(result); |
| 102 |
[prob, mpe] = max(marginal.T); |
| 103 |
|
| 104 |
% handle impossible cases |
| 105 |
if ~(prob>0) |
| 106 |
mpe = []; |
| 107 |
ll = -inf; |
| 108 |
%warning('evidence has zero probability')
|
| 109 |
return |
| 110 |
end |
| 111 |
|
| 112 |
ll = log(prob); |
| 113 |
|
| 114 |
% Do forward phase |
| 115 |
for ii=2:n |
| 116 |
marginal = pot_to_marginal(B{ii});
|
| 117 |
mpeidx = []; |
| 118 |
for jj=order(1:length(mpe)) |
| 119 |
assert(ismember(jj, marginal.domain)) %%% bug |
| 120 |
temp = find_equiv_posns(jj, marginal.domain); |
| 121 |
mpeidx = [mpeidx, temp] ; |
| 122 |
if isempty(temp) |
| 123 |
mpeidx = [mpeidx, Inf] ; |
| 124 |
end |
| 125 |
end |
| 126 |
[mpeidxsorted sortedtompe] = sort(mpeidx) ; |
| 127 |
|
| 128 |
% maximize the matrix obtained from assigning values from previous buckets. |
| 129 |
% this is done by building a string and using eval. |
| 130 |
|
| 131 |
kk=1; |
| 132 |
sargs = '(';
|
| 133 |
for jj=1:length(marginal.domain) |
| 134 |
if (jj~=1) |
| 135 |
sargs = [sargs, ',']; |
| 136 |
end |
| 137 |
if (mpeidxsorted(kk)==jj) |
| 138 |
sargs = [sargs, num2str(mpe(sortedtompe(kk)))]; |
| 139 |
if (kk<length(mpe)) |
| 140 |
kk = kk+1 ; |
| 141 |
end |
| 142 |
else |
| 143 |
sargs = [sargs, ':']; |
| 144 |
end |
| 145 |
end |
| 146 |
sargs = [sargs, ')'] ; |
| 147 |
eval(['[val, loc] = max(marginal.T', sargs, ');']) |
| 148 |
mpe = [mpe loc]; |
| 149 |
end |
| 150 |
[I,J] = sort(order); |
| 151 |
mpe = mpe(J); |
| 152 |
|
| 153 |
|
| 154 |
|
| 155 |
%%%%%%%%% |
| 156 |
|
| 157 |
function b = bucket_num(domain, order) |
| 158 |
|
| 159 |
b = max(find_equiv_posns(domain, order)); |
| 160 |
|