wolffd@0
|
function [mpe, ll] = calc_mpe_bucket(bnet, new_evidence, max_over)
%
% PURPOSE:
% CALC_MPE_BUCKET Computes the most probable explanation to the network nodes
% given the evidence, using bucket elimination (max-product elimination).
%
% [mpe, ll] = calc_mpe_bucket(bnet, new_evidence, max_over)
%
% INPUT:
% bnet - the bayesian network
% new_evidence - optional, if specified - evidence to be incorporated [cell(1,n)].
%                new_evidence{i} is the observed value of node i, or [] if hidden.
% max_over - optional, if specified determines the variable elimination order [1:n].
%            Must be a permutation of 1:n (row or column vector).
%
% OUTPUT:
% mpe - the MPE assignment for the net variables (or [] if no satisfying assignment);
%       mpe(i) is the value of node i.
% ll - log probability of the MPE assignment (-inf if no satisfying assignment).
%
% Notes:
% 1. Adapted from '@var_elim_inf_engine\marginal_nodes' for MPE by Ron Zohar, 8/7/01
% 2. Only discrete potentials are supported at this time.
% 3. Complexity: time and space are exponential in the induced width w* of the
%    elimination order, i.e. O(n*exp(w*)) where n is the number of nodes.
% 4. Implementation based on:
%    - R. Dechter, "Bucket Elimination: A Unifying Framework for Probabilistic Inference",
%      UAI 96, pp. 211-219.

ns = bnet.node_sizes;
n = length(bnet.dag);
evidence = cell(1,n);
if (nargin<2)
  new_evidence = evidence;
end

onodes = find(~isemptycell(new_evidence)); % observed nodes
onodes = onodes(:)'; % row vector, so that "for" visits one node at a time
pot_type = determine_pot_type(bnet, onodes);

if ~strcmp(pot_type, 'd')
  error('only discrete potentials supported at this time')
end

% One potential per node. The CPDs are converted with an empty evidence
% cell; the real evidence is applied below by zeroing table entries.
CPT = cell(1,n);
for i=1:n
  fam = family(bnet.dag, i);
  CPT{i} = convert_to_pot(bnet.CPD{bnet.equiv_class(i)}, pot_type, fam(:), evidence);
end

% handle observed nodes: set impossible cases' probability to zero
% rather than prune the matrix (this makes backtracking easier).
% The observed node's values run along the LAST dimension of its own table,
% so every slice of that dimension other than the observed value is zeroed.
for ii=onodes
  sCPT = struct(CPT{ii}); % violate object privacy
  impossible = setdiff(1:ns(ii), new_evidence{ii});
  idx = repmat({':'}, 1, length(sCPT.domain));
  idx{end} = impossible;
  sCPT.T(idx{:}) = 0;
  CPT{ii} = dpot(sCPT.domain, sCPT.sizes, sCPT.T);
end

if (nargin<3)
  max_over = (1:n);
end
% No attempt is made to optimize the order. Force a row so that "for"
% loops over it element by element even if a column was passed.
order = max_over(:)';

B = cell(1,n);
for b=1:n
  B{b} = mk_initial_pot(pot_type, [], [], [], []);
end

% Initialize the buckets: each CPT goes to the bucket of its variable that
% comes latest in the order.
for i=1:n
  b = bucket_num(domain_pot(CPT{i}), order);
  B{b} = multiply_pots(B{b}, CPT{i});
end

% Do backward phase: max out the variables from the last one in the order
% down to the second one. Messages over no variables are scalars; their
% logs are accumulated in logscale.
logscale = 0;
for i=order(end:-1:2)
  % max-ing over variable i which occurs in bucket j
  j = bucket_num(i, order);
  rest = mysetdiff(domain_pot(B{j}), i);
  temp = marginalize_pot(B{j}, rest, 1);
  b = bucket_num(domain_pot(temp), order);
  if isempty(b)
    % The message mentions no variable, so there is no bucket to put it in
    % (indexing B{[]} would fail). Keep its value as a constant factor.
    % NOTE(review): assumes a dpot over an empty domain stores a scalar T.
    sTemp = struct(temp); % violate object privacy
    logscale = logscale + log(sTemp.T);
    continue
  end
  sB = struct(B{b}); % violate object privacy
  if ~isempty(sB.domain)
    B{b} = multiply_pots(B{b}, temp);
  else
    B{b} = temp;
  end
end
result = B{1};
marginal = pot_to_marginal(result);
[prob, mpe] = max(marginal.T);

% handle impossible cases
if ~(prob>0) || ~(logscale>-inf)
  mpe = [];
  ll = -inf;
  %warning('evidence has zero probability')
  return
end

ll = log(prob) + logscale;

% Do forward phase: assign the variables in order. Every variable in bucket
% ii is order(ii) or an earlier, already assigned variable. Fixing the
% assigned ones therefore leaves a vector over order(ii), which is maximized.
for ii=2:n
  marginal = pot_to_marginal(B{ii});
  % Not every earlier variable has to occur in this bucket (with the default
  % order on a chain 1->2->3, bucket 3 does not mention node 1), so fix only
  % those that do.
  [assigned, where] = ismember(marginal.domain, order(1:length(mpe)));
  idx = repmat({':'}, 1, length(marginal.domain));
  idx(assigned) = num2cell(mpe(where(assigned)));
  [val, loc] = max(marginal.T(idx{:}));
  mpe = [mpe loc];
end

% mpe(k) currently holds the value of variable order(k); reindex it by node.
assignment = zeros(1,n);
assignment(order) = mpe;
mpe = assignment;

%%%%%%%%%

function b = bucket_num(domain, order)
% BUCKET_NUM Returns the bucket that a potential over DOMAIN is assigned to.
% b = bucket_num(domain, order)
% The bucket is the largest position in ORDER held by any variable of DOMAIN,
% i.e. the position of DOMAIN's variable that comes latest in the order.
% If no position is found (presumably when DOMAIN is empty), max([]) gives [].

posns = find_equiv_posns(domain, order); % positions of DOMAIN's variables in ORDER
b = max(posns);
