annotate toolboxes/FullBNT-1.0.7/bnt/inference/static/@belprop_mrf2_inf_engine/bp_mrf2.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
function [new_bel, niter, new_msg, edge_id, nstates] = bp_mrf2_general(adj_mat, pot, local_evidence, varargin)
% BP_MRF2_GENERAL Belief propagation on an MRF with pairwise potentials
% function [bel, niter, msg, edge_id, nstates] = bp_mrf2_general(adj_mat, pot, local_evidence, varargin)
%
% Input:
% adj_mat(i,j) = 1 iff there is an edge between nodes i and j.
%   The nonzero pattern must be symmetric (each edge carries messages both ways).
% pot(ki,kj,i,j) or pot{i,j}(ki,kj) = potential on edge between nodes i,j
%   If the potentials on all edges are the same,
%   you can just pass in 1 array, pot(ki,kj)
% local_evidence(state, node) or local_evidence{i}(k) = Pr(observation at node i | Xi=k)
%   In the cell case each local_evidence{i} is expected to be a column vector.
%
% Use cell arrays if the hidden nodes do not all have the same number of values.
%
% Output:
% bel(k,i) or bel{i}(k) = P(Xi=k|evidence)
% niter contains the number of iterations used
% msg(:,e) or msg{e} = final message on directed edge e, where
%   e = edge_id(i,j) is the message sent from node i to node j
% edge_id(i,j) = index of the directed edge i->j (0 if there is no edge)
% nstates = number of states: a scalar in the array case,
%   a 1 x nnodes vector in the cell case
%
% [ ... ] = bp_mrf2(..., 'param1',val1, 'param2',val2, ...)
% allows you to specify optional parameters as name/value pairs.
% Parameters names are below [default value in brackets]
%
% max_iter - max. num. iterations; [] also selects the default [ 5*nnodes]
% momentum - weight assigned to old message in convex combination
%            (useful for damping oscillations); 0 disables damping [0]
% tol - tolerance used to assess convergence [1e-3]
% maximize - 1 means use max-product, 0 means use sum-product [0]
% verbose - 1 means print error at every iteration and a summary at the end [0]
%
% fn - name of function to call at end of every iteration [ [] ]
% fnargs - if empty we call feval(fn, bel); otherwise
%          feval(fn, bel, iter, fnargs{:}) [ [] ]

nnodes = length(adj_mat);

[max_iter, momentum, tol, maximize, verbose, fn, fnargs] = ...
    process_options(varargin, 'max_iter', 5*nnodes, 'momentum', 0, ...
                    'tol', 1e-3, 'maximize', 0, 'verbose', 0, ...
                    'fn', [], 'fnargs', []);
if isempty(max_iter)
  max_iter = 5*nnodes;
end

% Sending i->j reads the messages k->i via edge_id(k,i) for every neighbour k
% of i, so every edge must be present in both directions.
if ~isequal(adj_mat ~= 0, (adj_mat ~= 0)')
  error('bp_mrf2_general: adj_mat must be symmetric');
end

use_cell = iscell(local_evidence);
if ~use_cell
  [nstates nnodes] = size(local_evidence);
end

% A single 2D array means every edge shares the same potential.
tied_pot = ~iscell(pot) && (ndims(pot) == 2);

% give each directed edge a unique number: edge_id(i,j) indexes the msg i->j
ndx = find(adj_mat);
nedges = length(ndx);
edge_id = zeros(1, nnodes*nnodes);
edge_id(ndx) = 1:nedges;
edge_id = reshape(edge_id, nnodes, nnodes);

% initialise every msg to normalise(ones) over the receiver's states,
% and the "previous" beliefs to the raw local evidence
if use_cell
  old_bel = cell(1, nnodes);
  nstates = zeros(1, nnodes);
  old_msg = cell(1, nedges);
  for i=1:nnodes
    nstates(i) = length(local_evidence{i});
    old_bel{i} = local_evidence{i};
  end
  for i=1:nnodes
    nbrs = find(adj_mat(:,i));
    for j=nbrs(:)'
      old_msg{edge_id(i,j)} = normalise(ones(nstates(j),1));
    end
  end
else
  old_bel = local_evidence;
  old_msg = zeros(nstates, nedges);
  m = normalise(ones(nstates,1));
  for i=1:nnodes
    nbrs = find(adj_mat(:,i));
    for j=nbrs(:)'
      old_msg(:, edge_id(i,j)) = m;
    end
  end
end

% Start from the initial state so the outputs are defined even if no
% iteration runs (e.g. max_iter == 0 or an empty graph).
new_msg = old_msg;
new_bel = old_bel;

converged = 0;
iter = 1;

while ~converged && (iter <= max_iter)

  % each node sends a msg to each of its neighbors
  for i=1:nnodes
    nbrs = find(adj_mat(i,:));
    for j=nbrs(:)'
      if tied_pot
        pot_ij = pot;
      elseif iscell(pot)
        pot_ij = pot{i,j};
      else
        pot_ij = pot(:,:,i,j);
      end
      pot_ij = pot_ij'; % now pot_ij(xj, xi)
      % so pot_ij * msg(xi) = sum_xi pot(xj,xi) msg(xi) = f(xj)

      % temp = local evidence at i times every msg sent to i except the one
      % from j. It is formed explicitly instead of dividing the msg from j out
      % of the full product: that division is undefined where the msg from j
      % has a zero entry, and patching those zeros loses information.
      if use_cell
        temp = local_evidence{i};
        for k=nbrs(:)'
          if k ~= j
            temp = temp .* old_msg{edge_id(k,i)};
          end
        end
        prev_m = old_msg{edge_id(i,j)};
      else
        temp = local_evidence(:,i);
        for k=nbrs(:)'
          if k ~= j
            temp = temp .* old_msg(:, edge_id(k,i));
          end
        end
        prev_m = old_msg(:, edge_id(i,j));
      end

      if maximize
        newm = max_mult(pot_ij, temp); % bottleneck
      else
        newm = pot_ij * temp;
      end
      newm = normalise(newm);
      if momentum > 0
        % damping: convex combination of the new msg and the previous msg i->j
        newm = normalise((1-momentum)*newm + momentum*prev_m);
      end
      if use_cell
        new_msg{edge_id(i,j)} = newm;
      else
        new_msg(:, edge_id(i,j)) = newm;
      end
    end % for j
  end % for i

  % each node multiplies all its incoming msgs and computes its local belief
  if use_cell
    for i=1:nnodes
      nbrs = find(adj_mat(:,i));
      b = local_evidence{i};
      for j=nbrs(:)'
        b = b .* new_msg{edge_id(j,i)};
      end
      new_bel{i} = normalise(b);
    end
    err = abs(cat(1,new_bel{:}) - cat(1, old_bel{:}));
  else
    for i=1:nnodes
      nbrs = find(adj_mat(:,i));
      b = local_evidence(:,i);
      for j=nbrs(:)'
        b = b .* new_msg(:, edge_id(j,i));
      end
      new_bel(:,i) = normalise(b);
    end
    err = abs(new_bel(:) - old_bel(:));
  end
  converged = all(err < tol);
  if verbose, fprintf('error at iter %d = %f\n', iter, sum(err)); end
  if ~isempty(fn)
    if isempty(fnargs)
      feval(fn, new_bel);
    else
      feval(fn, new_bel, iter, fnargs{:});
    end
  end

  iter = iter + 1;
  old_msg = new_msg;
  old_bel = new_bel;
end % while

niter = iter-1;

% Report the outcome only on request, and do not claim convergence when the
% loop stopped because it hit max_iter.
if verbose
  if converged
    fprintf('converged in %d iterations\n', niter);
  else
    fprintf('stopped after %d iterations without converging\n', niter);
  end
end