comparison toolboxes/FullBNT-1.0.7/bnt/inference/static/@var_elim_inf_engine/find_mpe.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
1 function mpe = find_mpe(engine, new_evidence, max_over)
2 % FIND_MPE Find the most probable explanation of the data (assignment to the hidden nodes)
3 % function mpe = find_mpe(engine, evidence, order)
4 %
5 % PURPOSE:
6 % CALC_MPE Computes the most probable explanation to the network nodes
7 % given the evidence.
8 %
9 % [mpe, ll] = calc_mpe(engine, new_evidence, max_over)
10 %
11 % INPUT:
12 % bnet - the bayesian network
13 % new_evidence - optional, if specified - evidence to be incorporated [cell(1,n)]
14 % max_over - optional, if specified determines the variable elimination order [1:n]
15 %
16 % OUTPUT:
17 % mpe - the MPE assignmet for the net variables (or [] if no satisfying assignment)
18 % ll - log assignment probability.
19 %
20 % Notes:
21 % 1. Adapted from '@var_elim_inf_engine\marginal_nodes' for MPE by Ron Zohar, 8/7/01
22 % 2. Only discrete potentials are supported at this time.
23 % 3. Complexity: O(nw*) where n is the number of nodes and w* is the induced tree width.
24 % 4. Implementation based on:
25 % - R. Dechter, "Bucket Elimination: A Unifying Framework for Probabilistic Inference",
26 % UA1 96, pp. 211-219.
27
28 bnet = bnet_from_engine(engine);
29 ns = bnet.node_sizes;
30 n = length(bnet.dag);
31 evidence = cell(1,n);
32 if (nargin<2)
33 new_evidence = evidence;
34 end
35
36 onodes = find(~isemptycell(new_evidence)); % observed nodes
37 hnodes = find(isemptycell(new_evidence)); % hidden nodes
38 pot_type = determine_pot_type(bnet, onodes);
39
40 if pot_type ~= 'd'
41 error('only disrete potentials supported at this time')
42 end
43
44 for i=1:n
45 fam = family(bnet.dag, i);
46 CPT{i} = convert_to_pot(bnet.CPD{bnet.equiv_class(i)}, pot_type, fam(:), evidence);
47 end
48
49 % handle observed nodes: set impossible cases' probability to zero
50 % rather than prun matrix (this makes backtracking easier)
51
52 for ii=onodes
53 lIdx = 1:ns(ii);
54 lIdx = setdiff(lIdx, new_evidence{ii});
55
56 sCPT=struct(CPT{ii}); % violate object privacy
57
58 sargs = '';
59 for jj=1:(length(sCPT.domain)-1)
60 sargs = [sargs, ':,'];
61 end
62 for jj=lIdx
63 eval(['sCPT.T(', sargs, num2str(jj), ')=0;']);
64 end
65 CPT{ii}=dpot(sCPT.domain, sCPT.sizes, sCPT.T);
66 end
67
68 B = cell(1,n);
69 for b=1:n
70 B{b} = mk_initial_pot(pot_type, [], [], [], []);
71 end
72
73 if (nargin<3)
74 max_over = (1:n);
75 end
76 order = max_over; % no attempt to optimize this
77
78
79 % Initialize the buckets with the CPDs assigned to them
80 for i=1:n
81 b = bucket_num(domain_pot(CPT{i}), order);
82 B{b} = multiply_pots(B{b}, CPT{i});
83 end
84
85 % Do backward phase
86 max_over = max_over(length(max_over):-1:1); % reverse
87 maximize = 1;
88 for i=max_over(1:end-1)
89 % max-ing over variable i which occurs in bucket j
90 j = bucket_num(i, order);
91 rest = mysetdiff(domain_pot(B{j}), i);
92 %temp = marginalize_pot_max(B{j}, rest);
93 temp = marginalize_pot(B{j}, rest, maximize);
94 b = bucket_num(domain_pot(temp), order);
95 % fprintf('maxing over bucket %d (var %d), putting result into bucket %d\n', j, i, b);
96 sB=struct(B{b}); % violate object privacy
97 if ~isempty(sB.domain)
98 B{b} = multiply_pots(B{b}, temp);
99 else
100 B{b} = temp;
101 end
102 end
103 result = B{1};
104 marginal = pot_to_marginal(result);
105 [prob, mpe] = max(marginal.T);
106
107 % handle impossible cases
108 if ~(prob>0)
109 mpe = [];
110 ll = -inf;
111 %warning('evidence has zero probability')
112 return
113 end
114
115 ll = log(prob);
116
117 % Do forward phase
118 for ii=2:n
119 marginal = pot_to_marginal(B{ii});
120 mpeidx = [];
121 for jj=order(1:length(mpe))
122 %assert(ismember(jj, marginal.domain)) %%% bug
123 temp = find_equiv_posns(jj, marginal.domain);
124 mpeidx = [mpeidx, temp] ;
125 if isempty(temp)
126 mpeidx = [mpeidx, Inf] ;
127 end
128 end
129 [mpeidxsorted sortedtompe] = sort(mpeidx) ;
130
131 % maximize the matrix obtained from assigning values from previous buckets.
132 % this is done by building a string and using eval.
133
134 kk=1;
135 sargs = '(';
136 for jj=1:length(marginal.domain)
137 if (jj~=1)
138 sargs = [sargs, ','];
139 end
140 if (mpeidxsorted(kk)==jj)
141 sargs = [sargs, num2str(mpe(sortedtompe(kk)))];
142 if (kk<length(mpe))
143 kk = kk+1 ;
144 end
145 else
146 sargs = [sargs, ':'];
147 end
148 end
149 sargs = [sargs, ')'] ;
150 eval(['[val, loc] = max(marginal.T', sargs, ');'])
151 mpe = [mpe loc];
152 end
153 [I,J] = sort(order);
154 mpe = mpe(J);
155
156 mpe = num2cell(mpe);
157
158 %%%%%%%%%
159
160 function b = bucket_num(domain, order)
161
162 b = max(find_equiv_posns(domain, order));
163