function [new_bel, niter, new_msg, edge_id, nstates] = bp_mrf2_general(adj_mat, pot, local_evidence, varargin)
% BP_MRF2_GENERAL Belief propagation on an MRF with pairwise potentials
% function [bel, niter, msg, edge_id, nstates] = bp_mrf2_general(adj_mat, pot, local_evidence, varargin)
%
% Input:
% adj_mat(i,j) = 1 iff there is an edge between nodes i and j
%   (messages are looked up in both directions, so adj_mat is expected
%    to be symmetric)
% pot(ki,kj,i,j) or pot{i,j}(ki,kj) = potential on edge between nodes i,j
%   If the potentials on all edges are the same,
%   you can just pass in 1 array, pot(ki,kj)
% local_evidence(state, node) or local_evidence{i}(k) = Pr(observation at node i | Xi=k)
%
% Use cell arrays if the hidden nodes do not all have the same number of values.
% In the cell case each local_evidence{i} is combined elementwise with
% column-vector messages, so it should be a column vector.
%
% Output:
% bel(k,i) or bel{i}(k) = P(Xi=k|evidence)
% niter contains the number of iterations used
% msg(:,e) or msg{e} = final normalised message on directed edge e
% edge_id(i,j) = index e of the message sent from node i to node j
%   (0 if there is no edge)
% nstates = number of states (a scalar, or a 1 x nnodes vector when
%   local_evidence is a cell array)
%
% If no iteration is executed (max_iter < 1), bel is the local evidence as
% passed in and msg holds the initial uniform messages.
%
% [ ... ] = bp_mrf2_general(..., 'param1',val1, 'param2',val2, ...)
% allows you to specify optional parameters as name/value pairs.
% Parameters names are below [default value in brackets]
%
% max_iter - max. num. iterations [ 5*nnodes]
% momentum - weight assigned to old message in convex combination
%            (useful for damping oscillations) - currently ignored [0]
% tol      - tolerance used to assess convergence [1e-3]
% maximize - 1 means use max-product, 0 means use sum-product [0]
% verbose  - 1 means print error at every iteration [0]
%
% fn     - name of function to call at end of every iteration [ [] ]
% fnargs - if non-empty we call feval(fn, bel, iter, fnargs{:});
%          if empty we call feval(fn, bel) [ [] ]

nnodes = length(adj_mat);

% momentum is parsed so that callers may pass it, but it is not used below.
[max_iter, momentum, tol, maximize, verbose, fn, fnargs] = ...
    process_options(varargin, 'max_iter', 5*nnodes, 'momentum', 0, ...
                    'tol', 1e-3, 'maximize', 0, 'verbose', 0, ...
                    'fn', [], 'fnargs', []);

use_cell = iscell(local_evidence);
if ~use_cell
  [nstates, nnodes] = size(local_evidence);
end

% A plain 2D array means one potential shared by every edge.
tied_pot = ~iscell(pot) && (ndims(pot)==2);

% give each (directed) edge a unique number
ndx = find(adj_mat);
nedges = length(ndx);
edge_id = zeros(1, nnodes*nnodes);
edge_id(ndx) = 1:nedges;
edge_id = reshape(edge_id, nnodes, nnodes);

% initialise messages to uniform distributions
old_bel = local_evidence;
if use_cell
  nstates = zeros(1, nnodes);
  for i=1:nnodes
    nstates(i) = length(local_evidence{i});
  end
  old_msg = cell(1, nedges);
  for i=1:nnodes
    nbrs = find(adj_mat(:,i));
    for j=nbrs(:)'
      % the message from i to j is a distribution over the states of j
      old_msg{edge_id(i,j)} = normalise(ones(nstates(j),1));
    end
  end
else
  old_msg = zeros(nstates, nedges);
  m = normalise(ones(nstates,1));
  for i=1:nnodes
    nbrs = find(adj_mat(:,i));
    for j=nbrs(:)'
      old_msg(:, edge_id(i,j)) = m;
    end
  end
end

% Define the outputs up front so they exist even if the loop body never
% runs or the graph has no edges.
new_msg = old_msg;
new_bel = old_bel;

converged = 0;
iter = 1;

while ~converged && (iter <= max_iter)

  % each node sends a msg to each of its neighbors
  % (all messages are computed from old_msg, i.e. a parallel update)
  for i=1:nnodes
    nbrs = find(adj_mat(i,:));
    for j=nbrs(:)'
      if tied_pot
        pot_ij = pot;
      elseif iscell(pot)
        pot_ij = pot{i,j};
      else
        pot_ij = pot(:,:,i,j);
      end
      pot_ij = pot_ij'; % now pot_ij(xj, xi)
      % so pot_ij * msg(xi) = sum_xi pot(xj,xi) msg(xi) = f(xj)

      % temp = local evidence times all incoming msgs except the one from j.
      % The product is formed directly rather than by dividing the full
      % product by the message from j, because that division is not
      % well defined when a message contains zeros.
      if use_cell
        temp = local_evidence{i};
        for k=nbrs(:)'
          if k==j, continue, end
          temp = temp .* old_msg{edge_id(k,i)};
        end
      else
        temp = local_evidence(:,i);
        for k=nbrs(:)'
          if k==j, continue, end
          temp = temp .* old_msg(:, edge_id(k,i));
        end
      end

      if maximize
        newm = max_mult(pot_ij, temp); % bottleneck
      else
        newm = pot_ij * temp;
      end
      newm = normalise(newm);
      if use_cell
        new_msg{edge_id(i,j)} = newm;
      else
        new_msg(:, edge_id(i,j)) = newm;
      end
    end % for j
  end % for i

  % each node multiplies all its incoming msgs and computes its local belief
  if use_cell
    for i=1:nnodes
      nbrs = find(adj_mat(:,i));
      b = local_evidence{i};
      for j=nbrs(:)'
        b = b .* new_msg{edge_id(j,i)};
      end
      new_bel{i} = normalise(b);
    end
    err = abs(cat(1,new_bel{:}) - cat(1, old_bel{:}));
  else
    for i=1:nnodes
      nbrs = find(adj_mat(:,i));
      b = local_evidence(:,i);
      for j=nbrs(:)'
        b = b .* new_msg(:,edge_id(j,i));
      end
      new_bel(:,i) = normalise(b);
    end
    err = abs(new_bel(:) - old_bel(:));
  end

  % converged when every belief entry moved by less than tol
  converged = all(err < tol);
  if verbose, fprintf('error at iter %d = %f\n', iter, sum(err)); end
  if ~isempty(fn)
    if isempty(fnargs)
      feval(fn, new_bel);
    else
      feval(fn, new_bel, iter, fnargs{:});
    end
  end

  iter = iter + 1;
  old_msg = new_msg;
  old_bel = new_bel;
end % while

niter = iter-1;

if converged
  fprintf('converged in %d iterations\n', niter);
else
  % the loop stopped because max_iter was reached, not because of tol
  fprintf('did not converge in %d iterations\n', niter);
end