function [new_bel, niter, new_msg, edge_id, nstates] = bp_mrf2_general(adj_mat, pot, local_evidence, varargin)
% BP_MRF2_GENERAL Belief propagation on an MRF with pairwise potentials
% function [bel, niter, msg, edge_id, nstates] = bp_mrf2_general(adj_mat, pot, local_evidence, varargin)
%
% Input:
% adj_mat(i,j) = 1 iff there is an edge between nodes i and j.
%   Messages are looked up in both directions (edge_id(i,j) and edge_id(j,i)),
%   so both adj_mat(i,j) and adj_mat(j,i) must be set for every edge.
% pot(ki,kj,i,j) or pot{i,j}(ki,kj) = potential on edge between nodes i,j
%   If the potentials on all edges are the same,
%   you can just pass in 1 array, pot(ki,kj)
% local_evidence(state, node) or local_evidence{i}(k) = Pr(observation at node i | Xi=k)
%
% Use cell arrays if the hidden nodes do not all have the same number of values.
%
% Output:
% bel(k,i) or bel{i}(k) = P(Xi=k|evidence)
% niter contains the number of iterations used
% msg(:,e) or msg{e} = normalised message on directed edge e from the last
%   iteration, where e = edge_id(i,j) is the message sent from node i to node j
% edge_id(i,j) = unique number of the directed edge i->j (0 if there is no edge)
% nstates = number of states: a scalar if local_evidence is an array,
%   a 1 x nnodes vector if local_evidence is a cell array
%
% If no iteration is executed (e.g. max_iter = 0), bel and msg hold their
% initial values (the local evidence and uniform messages respectively).
%
% [ ... ] = bp_mrf2_general(..., 'param1',val1, 'param2',val2, ...)
% allows you to specify optional parameters as name/value pairs.
% Parameters names are below [default value in brackets]
%
% max_iter - max. num. iterations [ 5*nnodes]
% momentum - weight assigned to old message in convex combination
%            (useful for damping oscillations) - currently ignored [0]
% tol      - tolerance used to assess convergence [1e-3]
% maximize - 1 means use max-product, 0 means use sum-product [0]
% verbose  - 1 means print error at every iteration and a final summary [0]
%
% fn     - name of function to call at end of every iteration [ [] ]
% fnargs - if empty we call feval(fn, bel), otherwise
%          feval(fn, bel, iter, fnargs{:}) [ [] ]

nnodes = length(adj_mat);

% momentum is parsed for interface compatibility but is not used below.
[max_iter, momentum, tol, maximize, verbose, fn, fnargs] = ...
    process_options(varargin, 'max_iter', 5*nnodes, 'momentum', 0, ...
                    'tol', 1e-3, 'maximize', 0, 'verbose', 0, ...
                    'fn', [], 'fnargs', []);

if iscell(local_evidence)
  use_cell = 1;
else
  use_cell = 0;
  [nstates, nnodes] = size(local_evidence);
end

pot_is_cell = iscell(pot);
if pot_is_cell
  tied_pot = 0;
else
  % a plain 2-D array means one potential shared by every edge
  tied_pot = (ndims(pot)==2);
end


% give each edge a unique number
ndx = find(adj_mat);
nedges = length(ndx);
edge_id = zeros(1, nnodes*nnodes);
edge_id(ndx) = 1:nedges;
edge_id = reshape(edge_id, nnodes, nnodes);

% initialise messages (uniform) and beliefs (local evidence)
if use_cell
  prod_of_msgs = cell(1, nnodes);
  old_bel = cell(1, nnodes);
  nstates = zeros(1, nnodes);
  old_msg = cell(1, nedges);
  for i=1:nnodes
    nstates(i) = length(local_evidence{i});
    prod_of_msgs{i} = local_evidence{i};
    old_bel{i} = local_evidence{i};
  end
  for i=1:nnodes
    nbrs = find(adj_mat(:,i));
    for j=nbrs(:)'
      % the message from i to j ranges over the states of j
      old_msg{edge_id(i,j)} = normalise(ones(nstates(j),1));
    end
  end
else
  prod_of_msgs = local_evidence;
  old_bel = local_evidence;
  old_msg = zeros(nstates, nedges);
  m = normalise(ones(nstates,1));
  for i=1:nnodes
    nbrs = find(adj_mat(:,i));
    for j=nbrs(:)'
      old_msg(:, edge_id(i,j)) = m;
    end
  end
end

% Preallocate the outputs so they are defined even if the loop body never
% runs, and so they do not grow element by element inside the loop.
new_msg = old_msg;
new_bel = old_bel;

converged = 0;
iter = 1;

while ~converged && (iter <= max_iter)

  % each node sends a msg to each of its neighbors
  for i=1:nnodes
    nbrs = find(adj_mat(i,:));
    for j=nbrs(:)'
      if tied_pot
        pot_ij = pot;
      elseif pot_is_cell
        pot_ij = pot{i,j};
      else
        pot_ij = pot(:,:,i,j);
      end
      pot_ij = pot_ij'; % now pot_ij(xj, xi)
      % so pot_ij * msg(xi) = sum_xi pot(xj,xi) msg(xi) = f(xj)

      % temp = local evidence at i times all incoming msgs except the one
      % from j. This is computed as an explicit product rather than by
      % dividing the msg from j out of the full product, because the
      % division is not valid when that msg contains zeros.
      if use_cell
        temp = local_evidence{i};
        for k=nbrs(:)'
          if k==j, continue, end
          temp = temp .* old_msg{edge_id(k,i)};
        end
      else
        temp = local_evidence(:,i);
        for k=nbrs(:)'
          if k==j, continue, end
          temp = temp .* old_msg(:, edge_id(k,i));
        end
      end

      if maximize
        newm = max_mult(pot_ij, temp); % bottleneck
      else
        newm = pot_ij * temp;
      end
      newm = normalise(newm);
      if use_cell
        new_msg{edge_id(i,j)} = newm;
      else
        new_msg(:, edge_id(i,j)) = newm;
      end
    end % for j
  end % for i

  % each node multiplies all its incoming msgs and computes its local belief
  if use_cell
    for i=1:nnodes
      nbrs = find(adj_mat(:,i));
      prod_of_msgs{i} = local_evidence{i};
      for j=nbrs(:)'
        prod_of_msgs{i} = prod_of_msgs{i} .* new_msg{edge_id(j,i)};
      end
      new_bel{i} = normalise(prod_of_msgs{i});
    end
    err = abs(cat(1,new_bel{:}) - cat(1, old_bel{:}));
  else
    for i=1:nnodes
      nbrs = find(adj_mat(:,i));
      prod_of_msgs(:,i) = local_evidence(:,i);
      for j=nbrs(:)'
        prod_of_msgs(:,i) = prod_of_msgs(:,i) .* new_msg(:,edge_id(j,i));
      end
      new_bel(:,i) = normalise(prod_of_msgs(:,i));
    end
    err = abs(new_bel(:) - old_bel(:));
  end
  % converged when every belief entry changed by less than tol
  converged = all(err < tol);
  if verbose, fprintf('error at iter %d = %f\n', iter, sum(err)); end
  if ~isempty(fn)
    % fnargs defaults to [], which cannot be expanded with {:}, hence the
    % separate call form when no extra arguments were supplied.
    if isempty(fnargs)
      feval(fn, new_bel);
    else
      feval(fn, new_bel, iter, fnargs{:});
    end
  end

  iter = iter + 1;
  old_msg = new_msg;
  old_bel = new_bel;
end % while

niter = iter-1;

% Only report when asked to, and do not claim convergence if the loop
% stopped because max_iter was reached.
if verbose
  if converged
    fprintf('converged in %d iterations\n', niter);
  else
    fprintf('did not converge after %d iterations\n', niter);
  end
end