Mercurial > hg > camir-aes2014
diff toolboxes/FullBNT-1.0.7/bnt/general/Old/calc_mpe_bucket.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/toolboxes/FullBNT-1.0.7/bnt/general/Old/calc_mpe_bucket.m Tue Feb 10 15:05:51 2015 +0000 @@ -0,0 +1,160 @@ +function [mpe, ll] = calc_mpe_bucket(bnet, new_evidence, max_over) +% +% PURPOSE: +% CALC_MPE Computes the most probable explanation to the network nodes +% given the evidence. +% +% [mpe, ll] = calc_mpe(engine, new_evidence, max_over) +% +% INPUT: +% bnet - the bayesian network +% new_evidence - optional, if specified - evidence to be incorporated [cell(1,n)] +% max_over - optional, if specified determines the variable elimination order [1:n] +% +% OUTPUT: +% mpe - the MPE assignmet for the net variables (or [] if no satisfying assignment) +% ll - log assignment probability. +% +% Notes: +% 1. Adapted from '@var_elim_inf_engine\marginal_nodes' for MPE by Ron Zohar, 8/7/01 +% 2. Only discrete potentials are supported at this time. +% 3. Complexity: O(nw*) where n is the number of nodes and w* is the induced tree width. +% 4. Implementation based on: +% - R. Dechter, "Bucket Elimination: A Unifying Framework for Probabilistic Inference", +% UA1 96, pp. 211-219. + + +ns = bnet.node_sizes; +n = length(bnet.dag); +evidence = cell(1,n); +if (nargin<2) + new_evidence = evidence; +end + +onodes = find(~isemptycell(new_evidence)); % observed nodes +hnodes = find(isemptycell(new_evidence)); % hidden nodes +pot_type = determine_pot_type(bnet, onodes); + +if pot_type ~= 'd' + error('only disrete potentials supported at this time') +end + +for i=1:n + fam = family(bnet.dag, i); + CPT{i} = convert_to_pot(bnet.CPD{bnet.equiv_class(i)}, pot_type, fam(:), evidence); +end + +% handle observed nodes: set impossible cases' probability to zero +% rather than prun matrix (this makes backtracking easier) + +for ii=onodes + lIdx = 1:ns(ii); + lIdx = setdiff(lIdx, new_evidence{ii}); + + sCPT=struct(CPT{ii}); % violate object privacy + + sargs = ''; + for jj=1:(length(sCPT.domain)-1) + sargs = [sargs, ':,']; + end + for jj=lIdx + eval(['sCPT.T(', sargs, num2str(jj), ')=0;']); + end + CPT{ii}=dpot(sCPT.domain, sCPT.sizes, sCPT.T); +end + +B = cell(1,n); +for b=1:n + B{b} = mk_initial_pot(pot_type, [], [], [], []); +end + +if (nargin<3) + max_over = (1:n); +end +order = max_over; % no attempt to optimize this + + +% Initialize the buckets with the CPDs assigned to them +for i=1:n + b = bucket_num(domain_pot(CPT{i}), order); + B{b} = multiply_pots(B{b}, CPT{i}); +end + +% Do backward phase +max_over = max_over(length(max_over):-1:1); % reverse +for i=max_over(1:end-1) + % max-ing over variable i which occurs in bucket j + j = bucket_num(i, order); + rest = mysetdiff(domain_pot(B{j}), i); + %temp = marginalize_pot_max(B{j}, rest); + temp = marginalize_pot(B{j}, rest, 1); + b = bucket_num(domain_pot(temp), order); + % fprintf('maxing over bucket %d (var %d), putting result into bucket %d\n', j, i, b); + sB=struct(B{b}); % violate object privacy + if ~isempty(sB.domain) + B{b} = multiply_pots(B{b}, temp); + else + B{b} = temp; + end +end +result = B{1}; +marginal = pot_to_marginal(result); +[prob, mpe] = max(marginal.T); + +% handle impossible cases +if ~(prob>0) + mpe = []; + ll = -inf; + %warning('evidence has zero probability') + return +end + +ll = log(prob); + +% Do forward phase +for ii=2:n + marginal = pot_to_marginal(B{ii}); + mpeidx = []; + for jj=order(1:length(mpe)) + assert(ismember(jj, marginal.domain)) %%% bug + temp = find_equiv_posns(jj, marginal.domain); + mpeidx = [mpeidx, temp] ; + if isempty(temp) + mpeidx = [mpeidx, Inf] ; + end + end + [mpeidxsorted sortedtompe] = sort(mpeidx) ; + + % maximize the matrix obtained from assigning values from previous buckets. + % this is done by building a string and using eval. + + kk=1; + sargs = '('; + for jj=1:length(marginal.domain) + if (jj~=1) + sargs = [sargs, ',']; + end + if (mpeidxsorted(kk)==jj) + sargs = [sargs, num2str(mpe(sortedtompe(kk)))]; + if (kk<length(mpe)) + kk = kk+1 ; + end + else + sargs = [sargs, ':']; + end + end + sargs = [sargs, ')'] ; + eval(['[val, loc] = max(marginal.T', sargs, ');']) + mpe = [mpe loc]; +end +[I,J] = sort(order); +mpe = mpe(J); + + + +%%%%%%%%% + +function b = bucket_num(domain, order) + +b = max(find_equiv_posns(domain, order)); +