annotate toolboxes/FullBNT-1.0.7/bnt/general/Old/calc_mpe_bucket.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function [mpe, ll] = calc_mpe_bucket(bnet, new_evidence, max_over)
%
% PURPOSE:
% CALC_MPE_BUCKET Computes the most probable explanation (MPE) for the network
% nodes given the evidence, using bucket elimination.
%
% [mpe, ll] = calc_mpe_bucket(bnet, new_evidence, max_over)
%
% INPUT:
% bnet         - the bayesian network
% new_evidence - optional; evidence to be incorporated [cell(1,n)]. Non-empty
%                cells hold the observed value of the corresponding node.
%                Default: no evidence.
% max_over     - optional; the variable elimination order, a permutation of 1:n
%                (row or column). Default (also used when []): 1:n.
%
% OUTPUT:
% mpe - the MPE assignment, mpe(i) being the value of node i
%       (or [] if there is no satisfying assignment)
% ll  - log of the probability of that assignment jointly with the evidence
%       (-Inf if there is no satisfying assignment).
%
% Notes:
% 1. Adapted from '@var_elim_inf_engine\marginal_nodes' for MPE by Ron Zohar, 8/7/01
% 2. Only discrete potentials are supported at this time.
% 3. Complexity: linear in n and exponential in w*, where n is the number of
%    nodes and w* is the induced width of the elimination order.
% 4. Implementation based on:
%    - R. Dechter, "Bucket Elimination: A Unifying Framework for Probabilistic Inference",
%      UAI 96, pp. 211-219.

ns = bnet.node_sizes;
n = length(bnet.dag);
if nargin < 2 || isempty(new_evidence)
  new_evidence = cell(1,n);
end
if nargin < 3 || isempty(max_over)
  max_over = 1:n;
end
% No attempt is made to optimize the order. Force a row vector so that the
% for-loops below iterate once per variable even if a column was passed.
order = max_over(:)';

onodes = find(~isemptycell(new_evidence)); % observed nodes
pot_type = determine_pot_type(bnet, onodes);
if ~strcmp(pot_type, 'd')
  error('only discrete potentials supported at this time')
end

% Build one potential per family with NO evidence entered; the evidence is
% applied below by zeroing table entries instead of shrinking the tables.
no_evidence = cell(1,n);
CPT = cell(1,n);
for i=1:n
  fam = family(bnet.dag, i);
  CPT{i} = convert_to_pot(bnet.CPD{bnet.equiv_class(i)}, pot_type, fam(:), no_evidence);
end

% Handle observed nodes: set the probability of every value other than the
% observed one to zero rather than pruning the table (this makes backtracking
% easier). The node indexes the last dimension of its own family potential.
for ii=onodes
  sCPT = struct(CPT{ii}); % violate object privacy
  idx = repmat({':'}, 1, length(sCPT.domain));
  idx{end} = setdiff(1:ns(ii), new_evidence{ii});
  sCPT.T(idx{:}) = 0;
  CPT{ii} = dpot(sCPT.domain, sCPT.sizes, sCPT.T);
end

% Initialize the buckets with the CPDs assigned to them: each potential goes
% to the bucket of its variable that appears last in ORDER.
B = cell(1,n);
for b=1:n
  B{b} = mk_initial_pot(pot_type, [], [], [], []);
end
for i=1:n
  b = bucket_num(domain_pot(CPT{i}), order);
  B{b} = multiply_pots(B{b}, CPT{i});
end

% Backward phase: max out order(n), ..., order(2) in turn, placing each
% message in the bucket of its latest remaining variable. A message whose
% domain is empty (bucket j involved variable i only) is a constant with no
% bucket to go to: it does not affect the argmax but is part of the MPE
% probability, so it is accumulated in SCALE.
scale = 1;
for i = fliplr(order(2:end))
  j = bucket_num(i, order);
  rest = mysetdiff(domain_pot(B{j}), i);
  msg = marginalize_pot(B{j}, rest, 1); % 1 => max- instead of sum-marginalization
  b = bucket_num(domain_pot(msg), order);
  if isempty(b)
    cpot = pot_to_marginal(msg);
    scale = scale * cpot.T;
  else
    sB = struct(B{b}); % violate object privacy
    if isempty(sB.domain)
      B{b} = msg;
    else
      B{b} = multiply_pots(B{b}, msg);
    end
  end
end

% Bucket 1 now only involves order(1): pick its best value.
root = pot_to_marginal(B{1});
[prob, best] = max(root.T(:));
prob = prob * scale;

% handle impossible cases
if ~(prob > 0)
  mpe = [];
  ll = -inf;
  %warning('evidence has zero probability')
  return
end
ll = log(prob);

% Forward phase: assign order(2), ..., order(n) in turn by maximizing each
% bucket with the already-assigned variables clamped to their values. A
% bucket need not mention every earlier variable; absent ones are skipped.
vals = zeros(1, n); % vals(k) is the value chosen for node order(k)
vals(1) = best;
for k = 2:n
  m = pot_to_marginal(B{k});
  idx = repmat({':'}, 1, length(m.domain));
  for p = 1:k-1
    pos = find(m.domain == order(p));
    if ~isempty(pos)
      idx{pos} = vals(p);
    end
  end
  [val, vals(k)] = max(m.T(idx{:}));
end

% Reindex from elimination order to node number.
mpe = zeros(1, n);
mpe(order) = vals;
wolffd@0 153
wolffd@0 154
wolffd@0 155 %%%%%%%%%
wolffd@0 156
function b = bucket_num(domain, order)
% BUCKET_NUM Select the bucket a potential over DOMAIN is assigned to.
% b = bucket_num(domain, order) returns the largest position in ORDER that
% holds one of the variables of DOMAIN, i.e. the bucket of whichever domain
% variable appears last in ORDER. If no position is found, the max of the
% empty position list is returned (presumably empty; callers in this file
% test isempty(b)).
posns = find_equiv_posns(domain, order); % where DOMAIN's variables sit in ORDER
b = max(posns);