Mercurial > hg > camir-aes2014
diff toolboxes/FullBNT-1.0.7/bnt/inference/static/@belprop_mrf2_inf_engine/bp_mrf2.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/toolboxes/FullBNT-1.0.7/bnt/inference/static/@belprop_mrf2_inf_engine/bp_mrf2.m Tue Feb 10 15:05:51 2015 +0000 @@ -0,0 +1,209 @@ +function [new_bel, niter, new_msg, edge_id, nstates] = bp_mrf2_general(adj_mat, pot, local_evidence, varargin) +% BP_MRF2_GENERAL Belief propagation on an MRF with pairwise potentials +% function [bel, niter] = bp_mrf2_general(adj_mat, pot, local_evidence, varargin) +% +% Input: +% adj_mat(i,j) = 1 iff there is an edge between nodes i and j +% pot(ki,kj,i,j) or pot{i,j}(ki,kj) = potential on edge between nodes i,j +% If the potentials on all edges are the same, +% you can just pass in 1 array, pot(ki,kj) +% local_evidence(state, node) or local_evidence{i}(k) = Pr(observation at node i | Xi=k) +% +% Use cell arrays if the hidden nodes do not all have the same number of values. +% +% Output: +% bel(k,i) or bel{i}(k) = P(Xi=k|evidence) +% niter contains the number of iterations used +% +% [ ... ] = bp_mrf2(..., 'param1',val1, 'param2',val2, ...) +% allows you to specify optional parameters as name/value pairs. +% Parameters names are below [default value in brackets] +% +% max_iter - max. num. iterations [ 5*nnodes] +% momentum - weight assigned to old message in convex combination +% (useful for damping oscillations) - currently ignored i[0] +% tol - tolerance used to assess convergence [1e-3] +% maximize - 1 means use max-product, 0 means use sum-product [0] +% verbose - 1 means print error at every iteration [0] +% +% fn - name of function to call at end of every iteration [ [] ] +% fnargs - we call feval(fn, bel, iter, fnargs{:}) [ [] ] + +nnodes = length(adj_mat); + +[max_iter, momentum, tol, maximize, verbose, fn, fnargs] = ... + process_options(varargin, 'max_iter', 5*nnodes, 'momentum', 0, ... + 'tol', 1e-3, 'maximize', 0, 'verbose', 0, ... + 'fn', [], 'fnargs', []); + +if iscell(local_evidence) + use_cell = 1; +else + use_cell = 0; + [nstates nnodes] = size(local_evidence); +end + +if iscell(pot) + tied_pot = 0; +else + tied_pot = (ndims(pot)==2); +end + + +% give each edge a unique number +ndx = find(adj_mat); +nedges = length(ndx); +edge_id = zeros(1, nnodes*nnodes); +edge_id(ndx) = 1:nedges; +edge_id = reshape(edge_id, nnodes, nnodes); + +% initialise messages +if use_cell + prod_of_msgs = cell(1, nnodes); + old_bel = cell(1, nnodes); + nstates = zeros(1, nnodes); + old_msg = cell(1, nedges); + for i=1:nnodes + nstates(i) = length(local_evidence{i}); + prod_of_msgs{i} = local_evidence{i}; + old_bel{i} = local_evidence{i}; + end + for i=1:nnodes + nbrs = find(adj_mat(:,i)); + for j=nbrs(:)' + old_msg{edge_id(i,j)} = normalise(ones(nstates(j),1)); + end + end +else + prod_of_msgs = local_evidence; + old_bel = local_evidence; + %old_msg = zeros(nstates, nnodes, nnodes); + old_msg = zeros(nstates, nedges); + m = normalise(ones(nstates,1)); + for i=1:nnodes + nbrs = find(adj_mat(:,i)); + for j=nbrs(:)' + old_msg(:, edge_id(i,j)) = m; + %old_msg(:,i,j) = m; + end + end +end + + +converged = 0; +iter = 1; + +while ~converged & (iter <= max_iter) + + % each node sends a msg to each of its neighbors + for i=1:nnodes + nbrs = find(adj_mat(i,:)); + for j=nbrs(:)' + if tied_pot + pot_ij = pot; + else + if iscell(pot) + pot_ij = pot{i,j}; + else + pot_ij = pot(:,:,i,j); + end + end + pot_ij = pot_ij'; % now pot_ij(xj, xi) + % so pot_ij * msg(xi) = sum_xi pot(xj,xi) msg(xi) = f(xj) + + if 1 + % Compute temp = product of all incoming msgs except from j + % by dividing out old msg from j from the product of all msgs sent to i + if use_cell + temp = prod_of_msgs{i}; + m = old_msg{edge_id(j,i)}; + else + temp = prod_of_msgs(:,i); + m = old_msg(:, edge_id(j,i)); + end + if any(m==0) + fprintf('iter=%d, send from i=%d to j=%d\n', iter, i, j); + keyboard + end + m = m + (m==0); % valid since m(k)=0 => temp(k)=0, so can replace 0's with anything + temp = temp ./ m; + temp_div = temp; + end + + if 1 + % Compute temp = product of all incoming msgs except from j in obvious way + if use_cell + %temp = ones(nstates(i),1); + temp = local_evidence{i}; + for k=nbrs(:)' + if k==j, continue, end; + temp = temp .* old_msg{edge_id(k,i)}; + end + else + %temp = ones(nstates,1); + temp = local_evidence(:,i); + for k=nbrs(:)' + if k==j, continue, end; + temp = temp .* old_msg(:, edge_id(k,i)); + end + end + end + %assert(approxeq(temp, temp_div)) + assert(approxeq(normalise(pot_ij * temp), normalise(pot_ij * temp_div))) + + if maximize + newm = max_mult(pot_ij, temp); % bottleneck + else + newm = pot_ij * temp; + end + newm = normalise(newm); + if use_cell + new_msg{edge_id(i,j)} = newm; + else + new_msg(:, edge_id(i,j)) = newm; + end + end % for j + end % for i + old_prod_of_msgs = prod_of_msgs; + + % each node multiplies all its incoming msgs and computes its local belief + if use_cell + for i=1:nnodes + nbrs = find(adj_mat(:,i)); + prod_of_msgs{i} = local_evidence{i}; + for j=nbrs(:)' + prod_of_msgs{i} = prod_of_msgs{i} .* new_msg{edge_id(j,i)}; + end + new_bel{i} = normalise(prod_of_msgs{i}); + end + err = abs(cat(1,new_bel{:}) - cat(1, old_bel{:})); + else + for i=1:nnodes + nbrs = find(adj_mat(:,i)); + prod_of_msgs(:,i) = local_evidence(:,i); + for j=nbrs(:)' + prod_of_msgs(:,i) = prod_of_msgs(:,i) .* new_msg(:,edge_id(j,i)); + end + new_bel(:,i) = normalise(prod_of_msgs(:,i)); + end + err = abs(new_bel(:) - old_bel(:)); + end + converged = all(err < tol); + if verbose, fprintf('error at iter %d = %f\n', iter, sum(err)); end + if ~isempty(fn) + if isempty(fnargs) + feval(fn, new_bel); + else + feval(fn, new_bel, iter, fnargs{:}); + end + end + + iter = iter + 1; + old_msg = new_msg; + old_bel = new_bel; +end % while + +niter = iter-1; + +fprintf('converged in %d iterations\n', niter); +