annotate crp.pl @ 3:974d7be8eec4 tip

Update to pack-based dcg utilities
author samer
date Tue, 03 Oct 2017 11:52:23 +0100
parents 2c8a10d9e3cb
children
rev   line source
samer@0 1 :- module(crp,
samer@0 2 [ empty_classes/1
samer@0 3 , classes_value/2
samer@0 4 , classes_counts/2
samer@0 5 , classes_update/3
samer@0 6 , seqmap_classes//2
samer@0 7 , dec_class//3
samer@0 8 , inc_class//1
samer@0 9 , remove_class//1
samer@0 10 , add_class//2
samer@0 11
samer@0 12 , crp_prob/5
samer@0 13 , crp_sample/5
samer@0 14 , crp_sample_obs/7
samer@0 15 , crp_sample_rm/5
samer@0 16 , crp_dist/6
samer@0 17
samer@0 18 , dp_sampler_teh/3
samer@0 19 , py_sampler_teh/4
samer@0 20 ]).
samer@0 21
samer@0 22 /** <module> Chinese Restaurant Process utilities
samer@0 23
samer@0 24 ==
samer@0 25 gem_model ---> dp(Alpha:nonneg)
samer@0 26 ; py(Alpha:nonneg,Discount:nonneg).
samer@0 27
samer@0 28 gamma_prior ---> gamma(nonneg,nonneg).
samer@0 29 beta_prior ---> beta(nonneg,nonneg).
samer@0 30 param_sampler == pred(+gem_model,-gem_model,+rndstate,-rndstate).
samer@0 31 ==
samer@0 32
samer@0 33 */
samer@0 34 :- meta_predicate seqmap_classes(4,+,?,?).
samer@0 35
samer@2 36 :- use_foreign_library(foreign(crp)).
samer@3 37 :- use_module(library(dcg_core)).
samer@3 38 :- use_module(library(dcg_pair)).
samer@3 39 :- use_module(library(dcg_macros)).
samer@0 40 :- use_module(library(math)).
samer@0 41 :- use_module(library(eval)).
samer@0 42 :- use_module(library(lazy)).
samer@0 43 :- use_module(library(randpred)).
samer@0 44 :- use_module(library(apply_macros)).
samer@0 45
samer@0 46
samer@0 47 %% crp_prob( +GEM:gem_model, +Classes:classes(A), +X:A, +PProb:float, -Prob:float) is det.
samer@0 48 %
samer@0 49 % Compute the probability Prob of observing X given a CRP
samer@0 50 % and a base probability of PProb.
samer@0 51
samer@0 52
%% crp_sample( +GEM:gem_model, +Classes:classes(A), -A:action(A))// is det.
%
% Sample a new value from CRP. Action A is either new, which means
% that the user should sample a new value from the base distribution,
% or old(X,ID), where X is an old value and ID is the class ID.
samer@0 58 % Operates in random state DCG.
samer@0 59
samer@0 60
samer@0 61 %% crp_sample_obs( +GEM:gem_model, +Classes:classes(A), +X:A, +PProb:float, -A:action)// is det.
samer@0 62 %
samer@0 63 % Sample class appropriate for observation of value X. PProb is the
samer@0 64 % base probability of X from the base distribution. Action A is new
samer@0 65 % or old(ID) where ID is the class id.
samer@0 66 % Operates in random state DCG.
samer@0 67
samer@0 68
samer@0 69 %% crp_sample_rm( +Classes:classes(A), +X:A, -C:class_id)// is det.
samer@0 70 %
samer@0 71 % Sample appropriate class from which to remove value X. C is the
samer@0 72 % class id of the chosen class.
samer@0 73 % Operates in random state DCG.
samer@0 74
samer@0 75
%% crp_dist( +GEM:gem_model, +Classes:classes(A), +Base:dist(A), -Dist:dist(A))// is det.
%
%  Get the posterior distribution associated with a node using the
%  stick-breaking method. Operates in random state DCG.
%
%  Only Dirichlet process models of the form dp(Alpha) have a clause;
%  the call fails for py/2 models.
%
%  When the classes hold at least one observation (Total > 0), the
%  weights of the existing classes come from dirichlet/4 over the class
%  counts, scaled by Total. They are followed by a lazily unfolded
%  stick-breaking tail whose initial stick mass is Alpha. Dist is then
%  lazy_discrete(Values, Weights, Total+Alpha). With no observations,
%  Dist is only the lazy stick-breaking tail, with unit mass.
crp_dist( dp(Alpha), classes(_,Counts,Values), Base, Dist, RS1, RS3) :-
    sum_list(Counts, Total),
    Norm is Total+Alpha,
    (   Total > 0
    ->  dirichlet(Counts, Probs1, RS1, RS2),
        lazy_dp(Alpha, Base, Alpha, ValuesT, ProbsT, RS2, RS3),
        % scale the Dirichlet weights by Total (mul/3 from library(math))
        % so that old classes and the new-class tail share the normaliser
        % Total+Alpha
        maplist(mul(Total), Probs1, Probs2),
        append(Probs2, ProbsT, ProbsA),
        append(Values, ValuesT, ValuesA),
        Dist = lazy_discrete(ValuesA, ProbsA, Norm)
    ;   lazy_dp(Alpha, Base, 1, ValuesT, ProbsT, RS1, RS3),
        Dist = lazy_discrete(ValuesT, ProbsT, 1)
    ).
samer@0 94
samer@0 95
samer@0 96 % --------------------------------------------------------------------------------
samer@0 97 % classes data structure (basic CRP stuff)
samer@0 98
% Portray hook: print a classes/3 term compactly as <crp|Counts:Values>,
% leaving out the class-count field.
user:portray(classes(_NumClasses, Counts, Values)) :-
    format('<crp|~p:~p>', [Counts, Values]).
samer@0 100
samer@0 101
%% empty_classes( -Classes:classes(_)) is det.
%
%  Classes is a classes structure containing no classes: a class count
%  of zero and empty count and value lists.
empty_classes(Classes) :-
    Classes = classes(0, [], []).
samer@0 106
samer@0 107
%% classes_value( +Classes:classes(A), +X:A) is semidet.
%% classes_value( +Classes:classes(A), -X:A) is nondet.
%
%  Check that X is one of the values represented in Classes.
%  If X is ground on entry, the check succeeds at most once, even when
%  several classes carry the same value. Otherwise X is unified with
%  each class value in turn on backtracking. Fails for empty classes.
classes_value(classes(_,_,Vals), X) :-
    (   ground(X)
    ->  memberchk(X, Vals)   % commit: duplicate values must not give duplicate proofs
    ;   member(X, Vals)
    ).
samer@0 114
samer@0 115
%% classes_counts( +Classes:classes(A), -Counts:list(natural)) is det.
%
%  Counts is the list of per-class counts held in Classes, in class-id
%  order.
classes_counts(Classes, Counts) :-
    Classes = classes(_, Counts, _).
samer@0 120
%% seqmap_classes( +P:pred(natural,A,T,T), +Classes:classes(A), +S1:T, -S2:T) is multi.
%
%  Sequentially apply phrase P to every class in Classes, threading the
%  DCG state from S1 to S2. For each class, P receives the number of
%  items in that class and the value (of type A) associated with it,
%  in class-id order.
seqmap_classes(P, classes(_NumClasses, Counts, Values)) -->
    seqmap(P, Counts, Values).
samer@0 126
% Inline calls to seqmap_classes/4: unpack the classes structure and
% call seqmap/5 directly on its count and value lists.
% NOTE(review): because the clause is added to user:goal_expansion/2,
% it presumably applies to seqmap_classes/4 goals in any module that
% inherits from user, not only in callers of this module.
user:goal_expansion(seqmap_classes(P, Classes, S1, S2), Expanded) :-
    Expanded = ( Classes = classes(_, Counts, Values),
                 seqmap(P, Counts, Values, S1, S2)
               ).
samer@0 128
%% dec_class( +ID:class_id, -C:natural, -X:A, +C1:classes(A), -C2:classes(A)) is det.
%
%  Decrement the count associated with class ID. C is the count after
%  decrementing and X is the value associated with class ID.
%  Fails if ID is 0, exceeds the number of classes, or if the class
%  count is already 0.
dec_class(ID,CI,X,classes(K,C1,V),classes(K,C2,V)) :-
    dec_nth(ID,_,CI,C1,C2),
    nth1(ID,V,X).

% dec_nth(+N, -Old, -New, +List1, -List2) is semidet.
%
%  List2 is List1 with its Nth element (1-based) Old replaced by New,
%  where New is Old-1. The cut commits once the target position is
%  reached, so no choice point is left behind (the second clause can
%  never succeed for N = 1).
dec_nth(1,X,Y,[X|T],[Y|T]) :- !, succ(Y,X).
dec_nth(N,A,B,[X|T1],[X|T2]) :- succ(M,N), dec_nth(M,A,B,T1,T2).
samer@0 136
%% inc_class( +ID:class_id, +C1:classes(A), -C2:classes(A)) is det.
%
%  Increment the count associated with class ID.
%  Fails if ID is 0 or exceeds the number of classes.
inc_class(ID,classes(K,C1,V),classes(K,C2,V)) :-
    inc_nth(ID,C1,C2).

% inc_nth(+N, +List1, -List2) is semidet.
%
%  List2 is List1 with its Nth element (1-based) incremented by one.
%  The cut commits once the target position is reached, so the call
%  leaves no choice point (the second clause can never succeed for N = 1).
inc_nth(1,[X|T],[Y|T]) :- !, succ(X,Y).
inc_nth(N,[X|T1],[X|T2]) :- succ(M,N), inc_nth(M,T1,T2).
samer@0 143
samer@0 144
%% remove_class( +ID:class_id, +C1:classes(A), -C2:classes(A)) is semidet.
%
%  Remove class ID entirely: its count and its value are deleted and
%  the class count is decremented. Classes with larger ids shift down
%  by one position. Fails if ID is out of range.
remove_class(ID, classes(NumClasses0,Counts0,Values0), classes(NumClasses,Counts,Values)) :-
    remove_from_list(ID, _RemovedCount, Counts0, Counts),
    remove_from_list(ID, _RemovedValue, Values0, Values),
    succ(NumClasses, NumClasses0).
samer@0 152
%% add_class( +X:A, -ID:class_id, +C1:classes(A), -C2:classes(A)) is det.
%
%  Append a new class holding value X with an initial count of 1.
%  ID is the id of the new class, which is one more than the previous
%  number of classes; the new class becomes the last one.
add_class(Value, ID, classes(NumClasses0,Counts0,Values0), classes(ID,Counts,Values)) :-
    succ(NumClasses0, ID),
    append(Counts0, [1], Counts),
    append(Values0, [Value], Values).
samer@0 160
samer@0 161
% remove_from_list(?N, ?X, ?List, ?Rest)
%
%  Rest is List with its Nth element (1-based), X, removed.
%  When N is bound, it is counted down towards 1 while walking the
%  list. When N is unbound, positions are enumerated on backtracking,
%  and N is computed after the recursive call returns.
remove_from_list(1, X, [X|Rest], Rest).
remove_from_list(N, X, [Y|Tail0], [Y|Tail]) :-
    (   nonvar(N)
    ->  succ(M, N),
        remove_from_list(M, X, Tail0, Tail)
    ;   remove_from_list(M, X, Tail0, Tail),
        succ(M, N)
    ).
samer@0 168
samer@0 169
samer@0 170 %------------------------------------------------------------------
samer@0 171 % Get posterior distribution at node using stick-breaking
samer@0 172 % construction.
samer@0 173
% lazy_dp(+Alpha, +Base, +Mass, -Values, -Probs)// is det.
%
%  Lazily unfold a stick-breaking sequence: Values are drawn with
%  call(Base, V) and Probs are GEM(Alpha) stick weights that break up
%  an initial stick of size Mass. spawn//1 supplies the random state
%  used by the lazy unfolding (presumably derived from the current one
%  so that later draws are independent; see library(dcg_core)).
lazy_dp(Alpha, Base, Mass, Values, Probs) -->
    spawn(Stream),
    { lazy_unfold(unfold_dp(Alpha, Base), Values, Probs, Mass-Stream, _) }.
samer@0 176
% lazy_dp_paired(+Alpha, +Base, +Mass, -ValuesProbs)// is det.
%
%  Like lazy_dp//5, but produces a single lazy list of Value:Weight
%  pairs instead of two parallel lists.
%  NOTE(review): not referenced by the code shown here; confirm whether
%  it is still needed.
lazy_dp_paired(Alpha, Base, Mass, ValuesProbs) -->
    spawn(Stream),
    { lazy_unfold(unfold_dp(Alpha, Base), ValuesProbs, Mass-Stream, _) }.
samer@0 179
% unfold_dp(+Alpha, +Base, -Value, -Weight)// and
% unfold_dp(+Alpha, +Base, -Value:Weight)//
%
%  One stick-breaking step over a Mass-RandState pair state: draw Value
%  with call(Base, Value) on the second (random state) component, then
%  draw its Weight with unfold_gem//2. The two-output form serves
%  lazy_dp//5 and the paired form serves lazy_dp_paired//4.
unfold_dp(Alpha, Base, Value, Weight) -->
    \> call(Base, Value),
    unfold_gem(Alpha, Weight).
unfold_dp(Alpha, Base, Value:Weight) -->
    \> call(Base, Value),
    unfold_gem(Alpha, Weight).
samer@0 182
% lazy_gem(A,Probs) --> spawn(S0), { lazy_unfold(unfold_gem(A),Probs,1-S0,_) }.
samer@0 184
% unfold_gem(+Alpha, -Weight)// is det.
%
%  One GEM(Alpha) stick-breaking step over a Rest-RandState pair state:
%  draw a breaking fraction from Beta(1,Alpha) on the random state
%  component, take that fraction of the remaining stick Rest0 as
%  Weight, and leave Rest0-Weight as the new remaining stick.
unfold_gem(Alpha, Weight) -->
    \> beta(1, Alpha, Fraction),
    \< trans(Rest0, Rest),
    { Weight is Fraction*Rest0 },
    { Rest is Rest0 - Weight }.
samer@0 189
%% classes_update( +Action:action(A), +C1:classes(A), -C2:classes(A)) is det.
%
%  Record one observation in a classes structure. An old(_,ID) action
%  increments the count of the existing class ID. A new(X,ID) action
%  appends a fresh class holding X and unifies ID with its id.
classes_update(old(_Value, ID), Classes0, Classes) :-
    inc_class(ID, Classes0, Classes).
classes_update(new(Value, ID), Classes0, Classes) :-
    add_class(Value, ID, Classes0, Classes).
samer@0 195
samer@0 196
samer@0 197
samer@0 198
samer@0 199 % PARAMETER SAMPLING
samer@0 200
samer@0 201
samer@0 202
samer@0 203 % ---------------------------------------------------------------
samer@0 204 % Initialisers
samer@0 205 % Samplers written in C.
samer@0 206
%% dp_sampler_teh( +Prior:gamma_prior, +Counts:list(list(natural)), -S:param_sampler) is det.
%
%  Prepares a predicate for sampling the concentration parameter of a
%  Dirichlet process under a gamma(A,B) prior.
%  Counts contains one list of class counts per node. The sampler is
%  crp:sample_dp_teh(A+K, B, Ns), implemented in the foreign library,
%  where K is the total number of classes over all lists and Ns gives
%  the total count in each list.
%  The sampler's =|gem_model|= arguments must be of the form =|dp(_)|=.
dp_sampler_teh( gamma(A,B), CX, crp:sample_dp_teh(ApSumKX,B,NX)) :-
    maplist(sum_list, CX, NX),   % total count per list
    maplist(length, CX, KX),     % number of classes per list
    sum_list(KX, SumKX),
    ApSumKX is A+SumKX.
samer@0 216
%% py_sampler_teh( +ThPrior:gamma_prior, +DiscPr:beta_prior, +Counts:list(list(natural)), -S:param_sampler) is det.
%
%  Prepares a predicate for sampling the concentration and discount
%  parameters of a Pitman-Yor process. ThPrior is the gamma prior on
%  the concentration and DiscPr the beta prior on the discount. Both
%  are passed unchanged, together with Counts (one list of class counts
%  per node), to the foreign sampler crp:sample_py_teh/7.
%  The sampler's =|gem_model|= arguments must be of the form =|py(_,_)|=.
py_sampler_teh( ThPrior, DiscPrior, CountsX, crp:Sampler) :-
    Sampler = sample_py_teh( ThPrior, DiscPrior, CountsX).
samer@0 224
samer@0 225 /*
samer@0 226 slow_sample_py_teh( gamma(A,B), beta(DA,DB), CountsX, py(Theta1,Disc1), py(Theta2,Disc2)) -->
% sample several sets of auxiliary variables, one per client node
samer@0 228 % seqmap( py_sample_s_z_w(Theta1,Disc1), CountsX, SX, NSX, ZX, WX),
samer@0 229 seqmap( py_sample_s_z_log_w(Theta1,Disc1), CountsX, SX, NSX, ZX, LogWX),
samer@0 230 { % maplist(log,WX,LogWX),
samer@0 231 sumlist(SX,SumSX),
samer@0 232 sumlist(NSX,SumNSX),
samer@0 233 sumlist(ZX,SumZX),
samer@0 234 sumlist(LogWX,SumLogWX),
samer@0 235 A1 is A+SumSX, B1 is B-SumLogWX,
samer@0 236 DA1 is DA+SumNSX, DB1 is DB+SumZX },
samer@0 237 gamma(A1, B1, Theta2),
samer@0 238 beta(DA1, DB1, Disc2).
samer@0 239
samer@0 240 py_sample_s_z_w(Theta,Disc,Counts,S,NS,Z,W) -->
samer@0 241 py_sample_bern_z(Disc,Counts,Z),
samer@0 242 py_sample_bern_s(Theta,Disc,Counts,S,NS),
samer@0 243 py_sample_beta_w(Theta,Counts,W).
samer@0 244
samer@0 245 py_sample_s_z_log_w(Theta,Disc,Counts,S,NS,Z,LogW) -->
samer@0 246 py_sample_bern_z(Disc,Counts,Z),
samer@0 247 py_sample_bern_s(Theta,Disc,Counts,S,NS),
samer@0 248 py_sample_beta_log_w(Theta,Counts,LogW).
samer@0 249
samer@0 250 py_sample_beta_w(_, [], 1) --> !.
samer@0 251 py_sample_beta_w(Theta, Counts, W) -->
samer@0 252 {sumlist(Counts,N), Th1 is Theta+1, N1 is N-1},
samer@0 253 beta( Th1, N1, W).
samer@0 254
samer@0 255 py_sample_beta_log_w(_, [], 0) --> !.
samer@0 256 py_sample_beta_log_w(Theta, Counts, LogW) -->
samer@0 257 {sumlist(Counts,N), Th1 is Theta+1, N1 is N-1},
samer@0 258 beta( Th1, N1, W), { LogW is log(W) }.
samer@0 259
samer@0 260 py_sample_bern_s(Theta,Disc,Counts,SumS,SumNS) -->
samer@0 261 ( {Counts=[_|Cm1], length(Cm1,Kminus1), numlist(1,Kminus1,KX)}
samer@0 262 -> {maplist(mul(Disc),KX,KDX)},
samer@0 263 sum_bernoulli(KDX, Theta, SumS),
samer@0 264 {SumNS is Kminus1 - SumS}
samer@0 265 ; {SumS=0,SumNS=0}
samer@0 266 ).
samer@0 267
samer@0 268 py_sample_bern_z(Disc,Counts,Z) -->
samer@0 269 {Disc1 is 1-Disc},
samer@0 270 seqmap( sample_bern_z(Disc1), Counts, ZX),
samer@0 271 {sumlist(ZX,Z)}.
samer@0 272
samer@0 273 sample_bern_z(Disc1,Count,SumZ) -->
samer@0 274 {CountM2 is Count-2},
samer@0 275 ( {CountM2<0} -> {SumZ=0}
samer@0 276 ; {numlist(0,CountM2,I)},
samer@0 277 sum_bernoulli(I, Disc1, SumZ)
samer@0 278 ).
samer@0 279
samer@0 280 sum_bernoulli(AX,B,T,S1,S2) :- sum_bernoulli(AX,B,0,T,S1,S2).
samer@0 281 sum_bernoulli([],_,T,T,S,S) :- !.
samer@0 282 sum_bernoulli([A|AX],B,T1,T3,S1,S3) :-
samer@0 283 bernoulli(A,B,X,S1,S2), T2 is T1+X,
samer@0 284 sum_bernoulli(AX,B,T2,T3,S2,S3).
samer@0 285
samer@0 286 % Gamma distribution with rate parameter B.
samer@0 287 :- procedure gamma(1,1).
samer@0 288 gamma(A,B,X) --> gamma(A,U), {X is U/B}.
samer@0 289
samer@0 290 % Bernoulli with unnormalised weights for 0 and 1.
samer@0 291 :- procedure bernoulli(1,1).
samer@0 292 bernoulli(A,B,X) -->
samer@0 293 uniform01(U),
samer@0 294 ({(A+B)*U<B} -> {X=1}; {X=0} ).
samer@0 295 */