Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/general/Old/calc_mpe_bucket.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [mpe, ll] = calc_mpe_bucket(bnet, new_evidence, max_over)
%
% PURPOSE:
%     CALC_MPE_BUCKET Computes the most probable explanation (MPE) of the
%     network nodes given the evidence, using bucket elimination.
%
%     [mpe, ll] = calc_mpe_bucket(bnet, new_evidence, max_over)
%
% INPUT:
%     bnet         - the bayesian network
%     new_evidence - optional, if specified - evidence to be incorporated [cell(1,n)]
%     max_over     - optional, if specified determines the variable elimination order [1:n]
%
% OUTPUT:
%     mpe - the MPE assignment for the net variables, indexed by node number
%           (or [] if no satisfying assignment)
%     ll  - log probability of the assignment (-Inf if no satisfying assignment).
%
% Notes:
%     1. Adapted from '@var_elim_inf_engine\marginal_nodes' for MPE by Ron Zohar, 8/7/01
%     2. Only discrete potentials are supported at this time.
%     3. Complexity: time and space grow exponentially with the induced width w*
%        of the elimination order, since each bucket holds a full table over
%        its domain.
%     4. Implementation based on:
%        - R. Dechter, "Bucket Elimination: A Unifying Framework for Probabilistic Inference",
%          UAI 96, pp. 211-219.
%     5. Parts of the network that share no variable with each other are
%        supported: their maximized (constant) contributions are folded into ll.

ns = bnet.node_sizes;
n = length(bnet.dag);
evidence = cell(1,n);
if (nargin<2)
  new_evidence = evidence;
end

onodes = find(~isemptycell(new_evidence)); % observed nodes
pot_type = determine_pot_type(bnet, onodes);

if pot_type ~= 'd'
  error('only discrete potentials supported at this time')
end

% One potential per node, built from its CPD over its family. Evidence is
% deliberately NOT passed here; observed nodes are handled just below.
CPT = cell(1,n);
for i=1:n
  fam = family(bnet.dag, i);
  CPT{i} = convert_to_pot(bnet.CPD{bnet.equiv_class(i)}, pot_type, fam(:), evidence);
end

% Handle observed nodes: set the probability of every value inconsistent with
% the evidence to zero rather than pruning the table. This keeps all tables
% full-sized, which makes backtracking easier. The zeroed values are indexed
% along the LAST table dimension, i.e. this assumes the node itself is the
% last variable of its family's domain (presumably family() lists it after
% its parents).
for ii=onodes
  impossible = setdiff(1:ns(ii), new_evidence{ii});

  sCPT = struct(CPT{ii}); % violate object privacy
  idx = repmat({':'}, 1, length(sCPT.domain));
  idx{end} = impossible;            % with a 1-D domain this is linear indexing
  sCPT.T(idx{:}) = 0;
  CPT{ii} = dpot(sCPT.domain, sCPT.sizes, sCPT.T);
end

B = cell(1,n);
for b=1:n
  B{b} = mk_initial_pot(pot_type, [], [], [], []);
end

if (nargin<3)
  max_over = (1:n);
end
order = max_over; % no attempt to optimize this

% Initialize the buckets: each CPT goes to the bucket of the latest-ordered
% variable in its domain.
for i=1:n
  b = bucket_num(domain_pot(CPT{i}), order);
  B{b} = multiply_pots(B{b}, CPT{i});
end

% Backward phase: from the last variable down to the second, maximize each
% bucket over its own variable and pass the result to the bucket of the
% latest-ordered variable left in the result's domain.
% A result with an EMPTY domain is a constant coming from a part of the
% network that shares no variable with the rest. It has no bucket to go to
% (bucket_num would return []), so it is accumulated in SCALE instead.
scale = 1;
for i=order(end:-1:2)
  j = bucket_num(i, order);
  rest = mysetdiff(domain_pot(B{j}), i);
  temp = marginalize_pot(B{j}, rest, 1); % third arg 1: maximize (cf. marginalize_pot_max)
  tdom = domain_pot(temp);
  if isempty(tdom)
    m = pot_to_marginal(temp);
    scale = scale * m.T;
  else
    b = bucket_num(tdom, order);
    sB = struct(B{b}); % violate object privacy
    if ~isempty(sB.domain)
      B{b} = multiply_pots(B{b}, temp);
    else
      B{b} = temp;
    end
  end
end

% Bucket 1 can only hold factors over order(1) alone (see bucket_num), so
% its table is indexed by the value of order(1).
marginal = pot_to_marginal(B{1});
[prob, mpe] = max(marginal.T(:));
prob = prob * scale;

% handle impossible cases
if ~(prob>0)
  mpe = [];
  ll = -inf;
  %warning('evidence has zero probability')
  return
end

ll = log(prob);

% Forward phase: assign variables in elimination order. Bucket ii's table
% involves order(ii) and possibly only SOME of the already-assigned
% variables. Fix those at their chosen values, leave order(ii) free, and take
% the best value. During this phase mpe(k) is the value of variable order(k).
for ii=2:n
  marginal = pot_to_marginal(B{ii});
  dom = marginal.domain;
  idx = repmat({':'}, 1, length(dom));
  for p=1:length(dom)
    k = find(order(1:ii-1) == dom(p));
    if ~isempty(k)
      idx{p} = mpe(k);
    end
  end
  slice = marginal.T(idx{:});
  [val, loc] = max(slice(:)); % only order(ii)'s dimension is left free
  mpe = [mpe loc];
end

% Reorder from elimination order to node order: mpe(v) is the value of node v.
[sorted_order, perm] = sort(order);
mpe = mpe(perm);
152 | |
153 | |
154 | |
155 %%%%%%%%% | |
156 | |
function b = bucket_num(domain, order)
% BUCKET_NUM Bucket that a factor over DOMAIN is assigned to.
%   b = bucket_num(domain, order) looks up where each variable of DOMAIN
%   sits in the elimination ORDER (via find_equiv_posns) and returns the
%   largest such position, i.e. the bucket of the latest-ordered variable.
positions = find_equiv_posns(domain, order);
b = max(positions);