annotate toolboxes/FullBNT-1.0.7/bnt/inference/static/@var_elim_inf_engine/find_mpe.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function [mpe, ll] = find_mpe(engine, new_evidence, max_over)
% FIND_MPE Find the most probable explanation of the data (assignment to the hidden nodes)
% [mpe, ll] = find_mpe(engine, new_evidence, max_over)
%
% PURPOSE:
% Computes the most probable explanation (MPE) of the network nodes given the
% evidence, by bucket elimination with max-marginalization followed by a
% forward (back-tracking) pass that recovers the maximizing assignment.
%
% INPUT:
% engine       - the inference engine; the network is obtained via bnet_from_engine
% new_evidence - optional; new_evidence{i} is the observed value of node i, or []
%                if node i is hidden [cell(1,n), default: no evidence]
% max_over     - optional; the variable elimination order, expected to be a
%                permutation of 1:n (row or column) [default 1:n]
%
% OUTPUT:
% mpe - cell(1,n); mpe{i} is the MPE value of node i
%       (or [] if the evidence has zero probability)
% ll  - log probability of the MPE assignment (jointly with the evidence);
%       -Inf if the evidence has zero probability.
%
% Notes:
% 1. Adapted from '@var_elim_inf_engine\marginal_nodes' for MPE by Ron Zohar, 8/7/01
% 2. Only discrete potentials are supported at this time.
% 3. Complexity: time and space exponential in the induced width w* of the
%    elimination order (linear in the number of nodes n for fixed w*).
% 4. Implementation based on:
%    - R. Dechter, "Bucket Elimination: A Unifying Framework for Probabilistic Inference",
%      UAI 96, pp. 211-219.

bnet = bnet_from_engine(engine);
ns = bnet.node_sizes;
n = length(bnet.dag);
% CPTs are built without evidence; evidence is applied below by zeroing entries.
evidence = cell(1,n);
if (nargin<2)
  new_evidence = evidence;
end

onodes = find(~isemptycell(new_evidence)); % observed nodes
pot_type = determine_pot_type(bnet, onodes);

if ~strcmp(pot_type, 'd')
  error('only disrete potentials supported at this time')
end

CPT = cell(1,n);
for i=1:n
  fam = family(bnet.dag, i);
  CPT{i} = convert_to_pot(bnet.CPD{bnet.equiv_class(i)}, pot_type, fam(:), evidence);
end

% Handle observed nodes: set the probability of values inconsistent with the
% evidence to zero rather than pruning the table (this keeps every table
% full-size, which makes back-tracking easier). Zeroing is done along the last
% dimension of the family table, i.e. the node itself (family lists it last).
for ii=onodes
  impossible = setdiff(1:ns(ii), new_evidence{ii});
  sCPT = struct(CPT{ii}); % violate object privacy
  subs = repmat({':'}, 1, length(sCPT.domain));
  subs{end} = impossible;
  sCPT.T(subs{:}) = 0;
  CPT{ii} = dpot(sCPT.domain, sCPT.sizes, sCPT.T);
end

if (nargin<3)
  max_over = 1:n;
end
order = max_over(:)'; % no attempt to optimize this; row form so 'for' iterates per element

% Initialize the buckets with the CPDs assigned to them: each CPD goes to the
% bucket of its highest-ordered variable.
B = cell(1,n);
for b=1:n
  B{b} = mk_initial_pot(pot_type, [], [], [], []);
end
for i=1:n
  b = bucket_num(domain_pot(CPT{i}), order);
  B{b} = multiply_pots(B{b}, CPT{i});
end

% Backward phase: max out the variables from the last in the order down to the
% second. Each max-marginal message goes to the bucket of its highest-ordered
% variable. A message with an empty domain is a constant (a connected component
% has been fully eliminated); it belongs to no bucket, so it is accumulated in
% const_factor and applied to the final probability instead.
const_factor = 1;
maximize = 1;
for i=order(end:-1:2)
  % max-ing over variable i which occurs in bucket j
  j = bucket_num(i, order);
  rest = mysetdiff(domain_pot(B{j}), i);
  msg = marginalize_pot(B{j}, rest, maximize);
  if isempty(domain_pot(msg))
    m = pot_to_marginal(msg);
    const_factor = const_factor * m.T(1);
    continue;
  end
  b = bucket_num(domain_pot(msg), order);
  sB = struct(B{b}); % violate object privacy
  if ~isempty(sB.domain)
    B{b} = multiply_pots(B{b}, msg);
  else
    B{b} = msg;
  end
end

% Bucket 1 now only mentions order(1): its maximum is the MPE probability.
marginal = pot_to_marginal(B{1});
[prob, first_val] = max(marginal.T(:));
prob = prob * const_factor;

% handle impossible cases
if ~(prob>0)
  mpe = [];
  ll = -inf;
  %warning('evidence has zero probability')
  return
end

ll = log(prob);

% Forward phase: assign the variables in elimination order. Bucket ii mentions
% only order(ii) and variables earlier in the order, so fixing the already
% assigned ones leaves a slice over order(ii) whose argmax is its MPE value.
val_of = zeros(1, n); % val_of(v) = MPE value of node v; 0 while unassigned
val_of(order(1)) = first_val;
for ii=2:n
  marginal = pot_to_marginal(B{ii});
  dom = marginal.domain;
  subs = repmat({':'}, 1, length(dom));
  for k=1:length(dom)
    if val_of(dom(k)) > 0
      subs{k} = val_of(dom(k));
    end
  end
  slice = marginal.T(subs{:});
  [dummy, loc] = max(slice(:));
  val_of(order(ii)) = loc;
end

mpe = num2cell(val_of);
wolffd@0 157
wolffd@0 158 %%%%%%%%%
wolffd@0 159
function b = bucket_num(domain, order)
% BUCKET_NUM Bucket a potential belongs to: the largest position, within the
% elimination ORDER, of any variable in DOMAIN (empty if none of them occurs).
b = find(ismember(order, domain), 1, 'last');
wolffd@0 163