Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/inference/static/@belprop_mrf2_inf_engine/bp_mrf2.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [new_bel, niter, new_msg, edge_id, nstates] = bp_mrf2_general(adj_mat, pot, local_evidence, varargin)
% BP_MRF2_GENERAL Belief propagation on an MRF with pairwise potentials
% function [bel, niter, msg, edge_id, nstates] = bp_mrf2_general(adj_mat, pot, local_evidence, varargin)
%
% Input:
% adj_mat(i,j) = 1 iff there is an edge between nodes i and j
%   (assumed symmetric: the msg i->j is built from the msgs k->i, k ~= j)
% pot(ki,kj,i,j) or pot{i,j}(ki,kj) = potential on edge between nodes i,j
%   If the potentials on all edges are the same,
%   you can just pass in 1 array, pot(ki,kj)
% local_evidence(state, node) or local_evidence{i}(k) = Pr(observation at node i | Xi=k)
%
% Use cell arrays if the hidden nodes do not all have the same number of values.
%
% Output:
% bel(k,i) or bel{i}(k) = P(Xi=k|evidence)   (bel{i} is a column vector)
% niter contains the number of iterations used
% msg(:,e) or msg{e} = last message sent along directed edge e = edge_id(i,j),
%   i.e. from node i to node j
% edge_id(i,j) = index of the directed edge i->j (0 if there is no edge)
% nstates = number of states (scalar for array inputs, 1 x nnodes vector for cell inputs)
%
% [ ... ] = bp_mrf2(..., 'param1',val1, 'param2',val2, ...)
% allows you to specify optional parameters as name/value pairs.
% Parameters names are below [default value in brackets]
%
% max_iter - max. num. iterations [ 5*nnodes]
% momentum - weight assigned to the previous message in a convex combination
%            with the newly computed one (useful for damping oscillations) [0]
% tol - tolerance used to assess convergence [1e-3]
% maximize - 1 means use max-product, 0 means use sum-product [0]
% verbose - 1 means print error at every iteration [0]
%
% fn - name of function to call at end of every iteration [ [] ]
% fnargs - if non-empty we call feval(fn, bel, iter, fnargs{:}),
%          otherwise feval(fn, bel) [ [] ]

nnodes = length(adj_mat);

[max_iter, momentum, tol, maximize, verbose, fn, fnargs] = ...
    process_options(varargin, 'max_iter', 5*nnodes, 'momentum', 0, ...
                    'tol', 1e-3, 'maximize', 0, 'verbose', 0, ...
                    'fn', [], 'fnargs', []);

use_cell = iscell(local_evidence);
if ~use_cell
  [nstates, nnodes] = size(local_evidence);
end

if iscell(pot)
  tied_pot = 0;
else
  tied_pot = (ndims(pot)==2);
end

% give each directed edge i->j a unique number edge_id(i,j) (0 = no edge)
ndx = find(adj_mat);
nedges = length(ndx);
edge_id = zeros(nnodes, nnodes);
edge_id(ndx) = 1:nedges;

% initialise messages (uniform) and beliefs (local evidence)
if use_cell
  nstates = zeros(1, nnodes);
  for i=1:nnodes
    % force column vectors so that products with the (column) messages
    % and pot_ij * temp below are well defined
    local_evidence{i} = local_evidence{i}(:);
    nstates(i) = length(local_evidence{i});
  end
  old_bel = local_evidence;
  old_msg = cell(1, nedges);
  for i=1:nnodes
    nbrs = find(adj_mat(:,i));
    for j=nbrs(:)'
      old_msg{edge_id(i,j)} = normalise(ones(nstates(j),1));
    end
  end
else
  old_bel = local_evidence;
  old_msg = zeros(nstates, nedges);
  m = normalise(ones(nstates,1));
  for i=1:nnodes
    nbrs = find(adj_mat(:,i));
    for j=nbrs(:)'
      old_msg(:, edge_id(i,j)) = m;
    end
  end
end

% preallocate the outputs; this also keeps them defined if the loop
% body never runs (e.g. max_iter = 0)
new_msg = old_msg;
new_bel = old_bel;

converged = 0;
iter = 1;

while ~converged && (iter <= max_iter)

  % each node sends a msg to each of its neighbors
  for i=1:nnodes
    nbrs = find(adj_mat(i,:));
    for j=nbrs(:)'
      if tied_pot
        pot_ij = pot;
      elseif iscell(pot)
        pot_ij = pot{i,j};
      else
        pot_ij = pot(:,:,i,j);
      end
      pot_ij = pot_ij'; % now pot_ij(xj, xi)
      % so pot_ij * msg(xi) = sum_xi pot(xj,xi) msg(xi) = f(xj)

      % temp = local evidence at i times all incoming msgs except the one from j.
      % This is computed directly rather than by dividing the full product by
      % the msg from j: if that msg is 0 in some state, the quotient is 0 there
      % even when the remaining factors are not.
      if use_cell
        temp = local_evidence{i};
        for k=nbrs(:)'
          if k ~= j
            temp = temp .* old_msg{edge_id(k,i)};
          end
        end
      else
        temp = local_evidence(:,i);
        for k=nbrs(:)'
          if k ~= j
            temp = temp .* old_msg(:, edge_id(k,i));
          end
        end
      end

      if maximize
        newm = max_mult(pot_ij, temp); % bottleneck
      else
        newm = pot_ij * temp;
      end
      newm = normalise(newm);

      % optional damping: convex combination with the previous msg i->j
      if momentum ~= 0
        if use_cell
          oldm = old_msg{edge_id(i,j)};
        else
          oldm = old_msg(:, edge_id(i,j));
        end
        newm = normalise(momentum*oldm + (1-momentum)*newm);
      end

      if use_cell
        new_msg{edge_id(i,j)} = newm;
      else
        new_msg(:, edge_id(i,j)) = newm;
      end
    end % for j
  end % for i

  % each node multiplies all its incoming msgs and computes its local belief
  if use_cell
    for i=1:nnodes
      nbrs = find(adj_mat(:,i));
      b = local_evidence{i};
      for j=nbrs(:)'
        b = b .* new_msg{edge_id(j,i)};
      end
      new_bel{i} = normalise(b);
    end
    err = abs(cat(1, new_bel{:}) - cat(1, old_bel{:}));
  else
    for i=1:nnodes
      nbrs = find(adj_mat(:,i));
      b = local_evidence(:,i);
      for j=nbrs(:)'
        b = b .* new_msg(:, edge_id(j,i));
      end
      new_bel(:,i) = normalise(b);
    end
    err = abs(new_bel(:) - old_bel(:));
  end
  converged = all(err < tol);
  if verbose, fprintf('error at iter %d = %f\n', iter, sum(err)); end
  if ~isempty(fn)
    if isempty(fnargs)
      feval(fn, new_bel);
    else
      feval(fn, new_bel, iter, fnargs{:});
    end
  end

  iter = iter + 1;
  old_msg = new_msg;
  old_bel = new_bel;
end % while

niter = iter-1;

% report honestly whether the tolerance was met or max_iter was exhausted
if converged
  fprintf('converged in %d iterations\n', niter);
else
  fprintf('did not converge in %d iterations\n', niter);
end