samer@0
|
:- module(crp,
    [   % classes data structure
        empty_classes/1,
        classes_value/2,
        classes_counts/2,
        classes_update/3,
        seqmap_classes//2,
        dec_class//3,
        inc_class//1,
        remove_class//1,
        add_class//2,

        % CRP probabilities, sampling and posterior distributions
        crp_prob/5,
        crp_sample/5,
        crp_sample_obs/7,
        crp_sample_rm/5,
        crp_dist/6,

        % concentration / discount parameter samplers
        dp_sampler_teh/3,
        py_sampler_teh/4
    ]).
|
samer@0
|
21
|
samer@0
|
22 /** <module> Chinese Restaurant Process utilities
|
samer@0
|
23
|
samer@0
|
24 ==
|
samer@0
|
25 gem_model ---> dp(Alpha:nonneg)
|
samer@0
|
26 ; py(Alpha:nonneg,Discount:nonneg).
|
samer@0
|
27
|
samer@0
|
28 gamma_prior ---> gamma(nonneg,nonneg).
|
samer@0
|
29 beta_prior ---> beta(nonneg,nonneg).
|
samer@0
|
30 param_sampler == pred(+gem_model,-gem_model,+rndstate,-rndstate).
|
samer@0
|
31 ==
|
samer@0
|
32
|
samer@0
|
33 */
|
samer@0
|
34 :- meta_predicate seqmap_classes(4,+,?,?).
|
samer@0
|
35
|
samer@0
|
36 :- load_foreign_library(foreign(crp)).
|
samer@0
|
37 :- use_module(library(dcgu)).
|
samer@0
|
38 :- use_module(library(math)).
|
samer@0
|
39 :- use_module(library(eval)).
|
samer@0
|
40 :- use_module(library(lazy)).
|
samer@0
|
41 :- use_module(library(randpred)).
|
samer@0
|
42 :- use_module(library(apply_macros)).
|
samer@0
|
43
|
samer@0
|
44
|
samer@0
|
45 %% crp_prob( +GEM:gem_model, +Classes:classes(A), +X:A, +PProb:float, -Prob:float) is det.
|
samer@0
|
46 %
|
samer@0
|
47 % Compute the probability Prob of observing X given a CRP
|
samer@0
|
48 % and a base probability of PProb.
|
samer@0
|
49
|
samer@0
|
50
|
samer@0
|
51 %% crp_sample( +GEM:gem_model, +Classes:classes(A), -A:action(A))// is det.
|
samer@0
|
52 %
|
samer@0
|
% Sample a new value from CRP. Action A is either new, which means
% that the user should sample a new value from the base distribution,
% or old(X,ID), where X is an old value and ID is the class ID.
|
samer@0
|
56 % Operates in random state DCG.
|
samer@0
|
57
|
samer@0
|
58
|
samer@0
|
59 %% crp_sample_obs( +GEM:gem_model, +Classes:classes(A), +X:A, +PProb:float, -A:action)// is det.
|
samer@0
|
60 %
|
samer@0
|
61 % Sample class appropriate for observation of value X. PProb is the
|
samer@0
|
62 % base probability of X from the base distribution. Action A is new
|
samer@0
|
63 % or old(ID) where ID is the class id.
|
samer@0
|
64 % Operates in random state DCG.
|
samer@0
|
65
|
samer@0
|
66
|
samer@0
|
67 %% crp_sample_rm( +Classes:classes(A), +X:A, -C:class_id)// is det.
|
samer@0
|
68 %
|
samer@0
|
69 % Sample appropriate class from which to remove value X. C is the
|
samer@0
|
70 % class id of the chosen class.
|
samer@0
|
71 % Operates in random state DCG.
|
samer@0
|
72
|
samer@0
|
73
|
samer@0
|
%% crp_dist( +GEM:gem_model, +Classes:classes(A), +Base:dist(A), -Dist:dist(A))// is det.
%
%  Gets the posterior distribution associated with a node, using the
%  stick-breaking construction for values not yet seen. Only the
%  dp(Alpha) model is handled by this clause.
%
%  Let N be the sum of the class counts.
%
%  If N > 0, the weights of the existing classes come from
%  dirichlet//2 applied to the counts, scaled by N. They are followed
%  by a lazily generated stick-breaking tail of new values whose
%  weights have total mass Alpha. Dist is then
%  lazy_discrete(AllValues, AllWeights, N+Alpha).
%
%  If N = 0, Dist is the stick-breaking tail alone, with total mass 1.
%
%  Operates in random state DCG.
%
%  NOTE(review): the mass reserved for new values is fixed at Alpha
%  (out of N+Alpha) instead of being sampled jointly with the class
%  weights; confirm this is the intended posterior.
crp_dist(dp(Alpha), classes(_, Counts, Values), Base, Dist) -->
    {   sum_list(Counts, Total),
        Norm is Total + Alpha
    },
    (   { Total > 0 }
    ->  dirichlet(Counts, OldWeights),
        lazy_dp(Alpha, Base, Alpha, NewValues, NewWeights),
        {   maplist(mul(Total), OldWeights, ScaledWeights),
            append(ScaledWeights, NewWeights, AllWeights),
            append(Values, NewValues, AllValues),
            Dist = lazy_discrete(AllValues, AllWeights, Norm)
        }
    ;   lazy_dp(Alpha, Base, 1, NewValues, NewWeights),
        { Dist = lazy_discrete(NewValues, NewWeights, 1) }
    ).
|
samer@0
|
92
|
samer@0
|
93
|
samer@0
|
94 % --------------------------------------------------------------------------------
|
samer@0
|
95 % classes data structure (basic CRP stuff)
|
samer@0
|
96
|
samer@0
|
% Hook into print/1 and format's ~p so that classes structures are
% shown compactly as <crp|Counts:Values>. The hook is declared multifile
% so that this clause coexists with portray/1 clauses defined in other
% files, instead of being treated as a redefinition.
:- multifile user:portray/1.

user:portray(classes(_, Counts, Vals)) :-
    format('<crp|~p:~p>', [Counts, Vals]).
|
samer@0
|
98
|
samer@0
|
99
|
samer@0
|
%% empty_classes( -Classes:classes(_)) is det.
%
%  Classes is the classes structure with no classes: zero classes,
%  no counts and no values.
empty_classes(Classes) :-
    Classes = classes(0, [], []).
|
samer@0
|
104
|
samer@0
|
105
|
samer@0
|
%% classes_value( +Classes:classes(A), +X:A) is semidet.
%% classes_value( +Classes:classes(A), -X:A) is multi.
%
%  True when X is one of the values stored in Classes. With X unbound,
%  the values are enumerated in class order on backtracking.
classes_value(Classes, X) :-
    Classes = classes(_, _, Values),
    member(X, Values).
|
samer@0
|
112
|
samer@0
|
113
|
samer@0
|
%% classes_counts( +Classes:classes(A), -Counts:list(natural)) is det.
%
%  Counts is the list of per-class item counts, in class order.
classes_counts(Classes, Counts) :-
    Classes = classes(_, Counts, _).
|
samer@0
|
118
|
samer@0
|
%% seqmap_classes( +P:pred(natural,A,T,T), +Classes:classes(A), +S1:T, -S2:T) is multi.
%
%  Sequentially applies phrase P to every class in Classes, threading the
%  state from S1 to S2. For each class, P receives the number of items
%  in the class and the value (of type A) associated with it.
seqmap_classes(P, classes(_, Counts, Values), S1, S2) :-
    seqmap(P, Counts, Values, S1, S2).
|
samer@0
|
124
|
samer@0
|
% Inline seqmap_classes//2 at compile time. A goal
% seqmap_classes(P, Classes, S1, S2) is rewritten into a unification that
% takes the classes structure apart, followed by a direct call to
% seqmap/5. The hook is declared multifile so that this clause coexists
% with goal_expansion/2 clauses defined in other files.
%
% NOTE(review): the expansion is not restricted to modules that import
% seqmap_classes from crp. The expanded seqmap/5 call is resolved in the
% module being compiled, which presumably must also see seqmap/5 (from
% library(dcgu)). Confirm.
:- multifile user:goal_expansion/2.

user:goal_expansion(seqmap_classes(P, CX, S1, S2),
                    (   CX = classes(_, Counts, Vals),
                        seqmap(P, Counts, Vals, S1, S2)
                    )).
|
samer@0
|
126
|
samer@0
|
%% dec_class( +ID:class_id, -C:natural, -X:A, +C1:classes(A), -C2:classes(A)) is det.
%
%  Decrements the count associated with class ID. C is the count after
%  decrementing and X is the value associated with class ID.
%  Fails if the count of class ID is already zero.
dec_class(N, CI, X, classes(K, C1, V), classes(K, C2, V)) :-
    dec_nth(N, _, CI, C1, C2),
    nth1(N, V, X).

% dec_nth(+N, -Old, -New, +L1, -L2)
%
% L2 is L1 with its Nth element (1-based) decremented from Old to New.
% The cut commits once N unifies with 1. Without it, the second clause is
% always retried for N = 1 and then fails (succ(_, 0) fails), leaving a
% spurious choice point. So the cut removes no solutions.
dec_nth(1, X, Y, [X|T], [Y|T]) :- !, succ(Y, X).
dec_nth(N, A, B, [X|T1], [X|T2]) :- succ(M, N), dec_nth(M, A, B, T1, T2).
|
134
|
samer@0
|
%% inc_class( +ID:class_id, +C1:classes(A), -C2:classes(A)) is det.
%
%  Increments the count associated with class ID.
inc_class(C, classes(K, C1, V), classes(K, C2, V)) :-
    inc_nth(C, C1, C2).

% inc_nth(+N, +L1, -L2)
%
% L2 is L1 with its Nth element (1-based) incremented by one. The cut
% commits once N unifies with 1. Without it, the second clause is always
% retried for N = 1 and then fails (succ(_, 0) fails), leaving a spurious
% choice point. So the cut removes no solutions.
inc_nth(1, [X|T], [Y|T]) :- !, succ(X, Y).
inc_nth(N, [X|T1], [X|T2]) :- succ(M, N), inc_nth(M, T1, T2).
|
samer@0
|
141
|
samer@0
|
142
|
samer@0
|
%% remove_class( +ID:class_id, +C1:classes(A), -C2:classes(A)) is det.
%
%  Removes class ID: its count and its value are deleted, and the number
%  of classes is reduced by one.
remove_class(Id, classes(NumBefore, CountsBefore, ValuesBefore),
                 classes(NumAfter, CountsAfter, ValuesAfter)) :-
    remove_from_list(Id, _, CountsBefore, CountsAfter),
    remove_from_list(Id, _, ValuesBefore, ValuesAfter),
    succ(NumAfter, NumBefore).
|
samer@0
|
150
|
samer@0
|
%% add_class( +X:A, -ID:class_id, +C1:classes(A), -C2:classes(A)) is det.
%
%  Adds a new class, with count 1, associated with value X. The class is
%  appended at the end, and ID (the new number of classes) is its id.
add_class(Value, NewId, classes(NumClasses, Counts0, Values0),
                        classes(NewId, Counts, Values)) :-
    succ(NumClasses, NewId),
    append(Counts0, [1], Counts),
    append(Values0, [Value], Values).
|
samer@0
|
158
|
samer@0
|
159
|
samer@0
|
% remove_from_list(?N, ?X, ?List, ?Rest)
%
% Rest is List with its Nth element (1-based), X, removed.
%
% With N bound, the call is deterministic. Previously, the catch-all
% clause was always retried for N = 1 and left a spurious choice point,
% which made remove_class//1 non-deterministic.
%
% With N unbound, positions are enumerated in increasing order on
% backtracking, exactly as before.
remove_from_list(N, X, [Y|T1], Rest) :-
    (   var(N)
    ->  (   N = 1, X = Y, Rest = T1
        ;   Rest = [Y|T2],
            remove_from_list(M, X, T1, T2),
            succ(M, N)
        )
    ;   N == 1
    ->  X = Y, Rest = T1
    ;   Rest = [Y|T2],
        succ(M, N),
        remove_from_list(M, X, T1, T2)
    ).
|
samer@0
|
166
|
samer@0
|
167
|
samer@0
|
%------------------------------------------------------------------
% Lazy stick-breaking construction used by crp_dist//4 for the part of
% the distribution that covers values not yet seen.

% lazy_dp(+Alpha, +Base, +Mass0, -Values, -Weights)//
%
% Values and Weights are lists built by lazy_unfold/5 from repeated
% unfold_dp//4 steps. The unfold state is the pair (Mass, RandState),
% starting from (Mass0, Rand0), where Rand0 comes from spawn//1
% (presumably a random state split off from the DCG state). The weights
% are stick-breaking pieces of Mass0.
lazy_dp(Alpha, Base, Mass0, Values, Weights) -->
    spawn(Rand0),
    {   Seed = (Mass0, Rand0),
        lazy_unfold(unfold_dp(Alpha, Base), Values, Weights, Seed, _)
    }.
|
samer@0
|
174
|
samer@0
|
% lazy_dp_paired(+Alpha, +Base, +Mass0, -ValueWeights)//
%
% Like lazy_dp//5, but produces a single list of Value:Weight pairs,
% built by lazy_unfold/4 from unfold_dp//3 steps.
% NOTE(review): not referenced elsewhere in the visible part of this
% module.
lazy_dp_paired(Alpha, Base, Mass0, ValueWeights) -->
    spawn(Rand0),
    {   Seed = (Mass0, Rand0),
        lazy_unfold(unfold_dp(Alpha, Base), ValueWeights, Seed, _)
    }.
|
samer@0
|
177
|
samer@0
|
% unfold_dp(+Alpha, +Base, -Value, -Weight)//
% unfold_dp(+Alpha, +Base, -ValueWeight)//
%
% One step of the lazy stick-breaking unfold. Calls Base under \>
% (presumably on the random-state component of the (Mass, RandState)
% pair) to draw Value. Then takes the next stick weight with
% unfold_gem//2. The second form returns the pair Value:Weight.
unfold_dp(Alpha, Base, Value, Weight) -->
    \> call(Base, Value),
    unfold_gem(Alpha, Weight).
unfold_dp(Alpha, Base, Value:Weight) -->
    \> call(Base, Value),
    unfold_gem(Alpha, Weight).
|
samer@0
|
180
|
samer@0
|
181 % lazy_gem(A,Probs) --> spawn(S0), { lazy_unfold(unfold_gem(A),Probs,(1,S0),_) }.
|
samer@0
|
182
|
samer@0
|
% unfold_gem(+Alpha, -Weight)//
%
% One stick-breaking step of a GEM(Alpha) sequence:
%  - a fraction Frac is drawn with beta(1, Alpha, Frac) under \>;
%  - the remaining stick mass Mass0 (accessed under \<, presumably the
%    mass component of the paired state) is split into Weight = Frac*Mass0;
%  - the new remaining mass is Mass0 - Weight.
unfold_gem(Alpha, Weight) -->
    \> beta(1, Alpha, Frac),
    \< trans(Mass0, Mass1),
    {   Weight is Frac*Mass0,
        Mass1 is Mass0 - Weight
    }.
|
samer@0
|
187
|
samer@0
|
%% classes_update( +Action:action(A), +C1:classes(A), -C2:classes(A)) is det.
%
%  Updates the classes structure with a new observation:
%  - old(_, ID) increments the count of the existing class ID;
%  - new(X, ID) adds a new class for value X, with ID its new id.
classes_update(old(_Value, Id), Before, After) :-
    inc_class(Id, Before, After).
classes_update(new(Value, Id), Before, After) :-
    add_class(Value, Id, Before, After).
|
samer@0
|
193
|
samer@0
|
194
|
samer@0
|
195
|
samer@0
|
196
|
samer@0
|
197 % PARAMETER SAMPLING
|
samer@0
|
198
|
samer@0
|
199
|
samer@0
|
200
|
samer@0
|
201 % ---------------------------------------------------------------
|
samer@0
|
202 % Initialisers
|
samer@0
|
203 % Samplers written in C.
|
samer@0
|
204
|
samer@0
|
%% dp_sampler_teh( +Prior:gamma_prior, +CountsX:list(list(natural)), -S:param_sampler) is det.
%
%  Prepares a predicate for sampling the concentration parameter of a
%  Dirichlet process.
%
%  Prior is gamma(A,B). CountsX holds one list of class counts per node.
%  The result is the term crp:sample_dp_teh(A+K, B, NX), where:
%  - K is the total number of classes over all nodes;
%  - NX lists the total number of items in each node.
%
%  The sampler's =|gem_model|= arguments must be of the form =|dp(_)|=.
dp_sampler_teh(gamma(A, B), CX, crp:sample_dp_teh(ApSumKX, B, NX)) :-
    maplist(sum_list, CX, NX),    % total number of items per node
    maplist(length, CX, KX),      % number of classes per node
    sum_list(KX, SumKX),
    ApSumKX is A + SumKX.
|
samer@0
|
214
|
samer@0
|
%% py_sampler_teh( +ThPrior:gamma_prior, +DiscPrior:beta_prior, +CountsX:list(list(natural)), -S:param_sampler) is det.
%
%  Prepares a predicate for sampling the concentration and discount
%  parameters of a Pitman-Yor process.
%
%  CountsX holds one list of class counts per node. No work is done
%  here: the result is the term
%  crp:sample_py_teh(ThPrior, DiscPrior, CountsX).
%
%  The sampler's =|gem_model|= arguments must be of the form =|py(_,_)|=.
py_sampler_teh(ThPrior, DiscPrior, CountsX, crp:Sampler) :-
    Sampler = sample_py_teh(ThPrior, DiscPrior, CountsX).
|
samer@0
|
222
|
samer@0
|
223 /*
|
samer@0
|
224 slow_sample_py_teh( gamma(A,B), beta(DA,DB), CountsX, py(Theta1,Disc1), py(Theta2,Disc2)) -->
|
samer@0
|
225 % do several lots of sampling auxillary variables, one per client node
|
samer@0
|
226 % seqmap( py_sample_s_z_w(Theta1,Disc1), CountsX, SX, NSX, ZX, WX),
|
samer@0
|
227 seqmap( py_sample_s_z_log_w(Theta1,Disc1), CountsX, SX, NSX, ZX, LogWX),
|
samer@0
|
228 { % maplist(log,WX,LogWX),
|
samer@0
|
229 sumlist(SX,SumSX),
|
samer@0
|
230 sumlist(NSX,SumNSX),
|
samer@0
|
231 sumlist(ZX,SumZX),
|
samer@0
|
232 sumlist(LogWX,SumLogWX),
|
samer@0
|
233 A1 is A+SumSX, B1 is B-SumLogWX,
|
samer@0
|
234 DA1 is DA+SumNSX, DB1 is DB+SumZX },
|
samer@0
|
235 gamma(A1, B1, Theta2),
|
samer@0
|
236 beta(DA1, DB1, Disc2).
|
samer@0
|
237
|
samer@0
|
238 py_sample_s_z_w(Theta,Disc,Counts,S,NS,Z,W) -->
|
samer@0
|
239 py_sample_bern_z(Disc,Counts,Z),
|
samer@0
|
240 py_sample_bern_s(Theta,Disc,Counts,S,NS),
|
samer@0
|
241 py_sample_beta_w(Theta,Counts,W).
|
samer@0
|
242
|
samer@0
|
243 py_sample_s_z_log_w(Theta,Disc,Counts,S,NS,Z,LogW) -->
|
samer@0
|
244 py_sample_bern_z(Disc,Counts,Z),
|
samer@0
|
245 py_sample_bern_s(Theta,Disc,Counts,S,NS),
|
samer@0
|
246 py_sample_beta_log_w(Theta,Counts,LogW).
|
samer@0
|
247
|
samer@0
|
248 py_sample_beta_w(_, [], 1) --> !.
|
samer@0
|
249 py_sample_beta_w(Theta, Counts, W) -->
|
samer@0
|
250 {sumlist(Counts,N), Th1 is Theta+1, N1 is N-1},
|
samer@0
|
251 beta( Th1, N1, W).
|
samer@0
|
252
|
samer@0
|
253 py_sample_beta_log_w(_, [], 0) --> !.
|
samer@0
|
254 py_sample_beta_log_w(Theta, Counts, LogW) -->
|
samer@0
|
255 {sumlist(Counts,N), Th1 is Theta+1, N1 is N-1},
|
samer@0
|
256 beta( Th1, N1, W), { LogW is log(W) }.
|
samer@0
|
257
|
samer@0
|
258 py_sample_bern_s(Theta,Disc,Counts,SumS,SumNS) -->
|
samer@0
|
259 ( {Counts=[_|Cm1], length(Cm1,Kminus1), numlist(1,Kminus1,KX)}
|
samer@0
|
260 -> {maplist(mul(Disc),KX,KDX)},
|
samer@0
|
261 sum_bernoulli(KDX, Theta, SumS),
|
samer@0
|
262 {SumNS is Kminus1 - SumS}
|
samer@0
|
263 ; {SumS=0,SumNS=0}
|
samer@0
|
264 ).
|
samer@0
|
265
|
samer@0
|
266 py_sample_bern_z(Disc,Counts,Z) -->
|
samer@0
|
267 {Disc1 is 1-Disc},
|
samer@0
|
268 seqmap( sample_bern_z(Disc1), Counts, ZX),
|
samer@0
|
269 {sumlist(ZX,Z)}.
|
samer@0
|
270
|
samer@0
|
271 sample_bern_z(Disc1,Count,SumZ) -->
|
samer@0
|
272 {CountM2 is Count-2},
|
samer@0
|
273 ( {CountM2<0} -> {SumZ=0}
|
samer@0
|
274 ; {numlist(0,CountM2,I)},
|
samer@0
|
275 sum_bernoulli(I, Disc1, SumZ)
|
samer@0
|
276 ).
|
samer@0
|
277
|
samer@0
|
278 sum_bernoulli(AX,B,T,S1,S2) :- sum_bernoulli(AX,B,0,T,S1,S2).
|
samer@0
|
279 sum_bernoulli([],_,T,T,S,S) :- !.
|
samer@0
|
280 sum_bernoulli([A|AX],B,T1,T3,S1,S3) :-
|
samer@0
|
281 bernoulli(A,B,X,S1,S2), T2 is T1+X,
|
samer@0
|
282 sum_bernoulli(AX,B,T2,T3,S2,S3).
|
samer@0
|
283
|
samer@0
|
284 % Gamma distribution with rate parameter B.
|
samer@0
|
285 :- procedure gamma(1,1).
|
samer@0
|
286 gamma(A,B,X) --> gamma(A,U), {X is U/B}.
|
samer@0
|
287
|
samer@0
|
288 % Bernoulli with unnormalised weights for 0 and 1.
|
samer@0
|
289 :- procedure bernoulli(1,1).
|
samer@0
|
290 bernoulli(A,B,X) -->
|
samer@0
|
291 uniform01(U),
|
samer@0
|
292 ({(A+B)*U<B} -> {X=1}; {X=0} ).
|
samer@0
|
293 */
|