:- module(crp,
    [ empty_classes/1
    , classes_value/2
    , classes_counts/2
    , classes_update/3
    , seqmap_classes//2
    , dec_class//3
    , inc_class//1
    , remove_class//1
    , add_class//2

    , crp_prob/5
    , crp_sample/5
    , crp_sample_obs/7
    , crp_sample_rm/5
    , crp_dist/6

    , dp_sampler_teh/3
    , py_sampler_teh/4
    ]).

/** Chinese Restaurant Process utilities

==
gem_model ---> dp(Alpha:nonneg)
             ; py(Alpha:nonneg,Discount:nonneg).

gamma_prior ---> gamma(nonneg,nonneg).
beta_prior  ---> beta(nonneg,nonneg).
param_sampler == pred(+gem_model,-gem_model,+rndstate,-rndstate).

% Number of classes, one count per class, one value per class
% (Counts and Values are parallel lists; class ids are 1-based positions).
classes(A) ---> classes(natural, list(natural), list(A)).
==

*/
:- meta_predicate seqmap_classes(4,+,?,?).

:- load_foreign_library(foreign(crp)).
:- use_module(library(dcgu)).
:- use_module(library(math)).
:- use_module(library(eval)).
:- use_module(library(lazy)).
:- use_module(library(randpred)).
:- use_module(library(apply_macros)).


% NOTE(review): crp_prob/5, crp_sample//3, crp_sample_obs//5 and
% crp_sample_rm//3 are exported but have no Prolog definition in this
% file; presumably they are provided by foreign(crp) — confirm.

%% crp_prob( +GEM:gem_model, +Classes:classes(A), +X:A, +PProb:float, -Prob:float) is det.
%
% Compute the probability Prob of observing X given a CRP
% and a base probability of PProb.


%% crp_sample( +GEM:gem_model, +Classes:classes(A), -A:action(A))// is det.
%
% Sample a new value from CRP. Action A is either new, which means
% that the user should sample a new value from the base distribution,
% or old(X,ID), where X is an old value and ID is the class ID.
% Operates in random state DCG.


%% crp_sample_obs( +GEM:gem_model, +Classes:classes(A), +X:A, +PProb:float, -A:action)// is det.
%
% Sample class appropriate for observation of value X. PProb is the
% base probability of X from the base distribution. Action A is new
% or old(ID) where ID is the class id.
% Operates in random state DCG.


%% crp_sample_rm( +Classes:classes(A), +X:A, -C:class_id)// is det.
%
% Sample appropriate class from which to remove value X. C is the
% class id of the chosen class.
% Operates in random state DCG.


%% crp_dist( +GEM:gem_model, +Classes:classes(A), +Base:dist(A), -Dist:dist(A))// is det.
%
% Get posterior distribution associated with node using stick breaking method.
% Operates in random state DCG.
%
% Dist is lazy_discrete(Values, Weights, Norm) where Weights are
% unnormalised and Norm is Total+Alpha (Total = sum of the class counts).
% When there are no observations, Dist is a fresh stick-breaking draw
% with total mass 1.
%
% NOTE(review): this file only has a clause for GEM = dp(Alpha); a call
% with py(_,_) fails rather than being det — confirm that is intended.
crp_dist( dp(Alpha), classes(_,Counts,Values), Base, Dist, RS1, RS3) :-
    sumlist(Counts, Total),
    Norm is Total+Alpha,
    (   Total > 0
    ->  % Weights for existing classes: a Dirichlet draw scaled by Total
        % (presumably mul(Total,P,W) means W is Total*P), followed by a lazy
        % stick-breaking tail whose mass starts at Alpha.
        dirichlet(Counts, Probs1, RS1, RS2),
        lazy_dp(Alpha, Base, Alpha, ValuesT, ProbsT, RS2, RS3),
        maplist(mul(Total), Probs1, Probs2),
        append(Probs2, ProbsT, ProbsA),
        append(Values, ValuesT, ValuesA),
        Dist = lazy_discrete(ValuesA, ProbsA, Norm)
    ;   % No observations: only the stick-breaking part, with unit mass.
        lazy_dp(Alpha, Base, 1, ValuesT, ProbsT, RS1, RS3),
        Dist = lazy_discrete(ValuesT, ProbsT, 1)
    ).


% --------------------------------------------------------------------------------
% classes data structure (basic CRP stuff)

% Compact printing of classes structures, e.g. <classes|[3,1]|[a,b]>.
% The format string must consume both arguments; an empty format string
% with two arguments raises a format error whenever a classes term is printed.
:- multifile user:portray/1.
user:portray(classes(_,Counts,Vals)) :- format('<classes|~w|~w>', [Counts,Vals]).


%% empty_classes( -Classes:classes(_)) is det.
%
% Unify Classes with an empty classes structure.
empty_classes(classes(0,[],[])).


%% classes_value( +Classes:classes(A), +X:A) is semidet.
%% classes_value( +Classes:classes(A), -X:A) is nondet.
%
% Check that X is one of the values represented in Classes.
% If X is unbound on entry, it is unified with all values on backtracking.
classes_value(classes(_,_,Vals),X) :- member(X,Vals).


%% classes_counts( +Classes:classes(A), -Counts:list(natural)) is det.
%
% Gets the list of counts, one per class.
classes_counts( classes(_,Counts,_), Counts).

%% seqmap_classes( +P:pred(natural,A,T,T), +Classes:classes(A), +S1:T, -S2:T) is multi.
%
% Sequentially apply phrase P to all classes. Arguments to P are the number of items
% in the class and the value (of type A) associated with it.
seqmap_classes(P, classes(_,Counts,Vals)) --> seqmap( P, Counts, Vals).

% Inline seqmap_classes/4 calls into a direct seqmap/5 over the counts and
% values lists. The clause is added to module user, so it presumably applies
% to seqmap_classes/4 goals in any module — NOTE(review): confirm intended.
:- multifile user:goal_expansion/2.
user:goal_expansion(seqmap_classes(P,CX,S1,S2), (CX=classes(_,Counts,Vals), seqmap(P, Counts,Vals,S1,S2))).

%% dec_class( +ID:class_id, -C:natural, -X:A, +C1:classes(A), -C2:classes(A)) is det.
%
% Decrement count associated with class ID. C is the count after
% decrementing and X is the value associated with class ID.
% Fails if the count of class ID is already 0 or ID is out of range.
dec_class(N,CI,X,classes(K,C1,V),classes(K,C2,V)) :- dec_nth(N,_,CI,C1,C2), nth1(N,V,X).

% dec_nth(+N, -Old, -New, +List1, -List2): List2 is List1 with its Nth
% element Old replaced by New, where Old is New+1. The cut commits to the
% base case so a bound index leaves no choice point behind.
dec_nth(1,X,Y,[X|T],[Y|T]) :- !, succ(Y,X).
dec_nth(N,A,B,[X|T1],[X|T2]) :- succ(M,N), dec_nth(M,A,B,T1,T2).

%% inc_class( +ID:class_id, +C1:classes(A), -C2:classes(A)) is det.
%
% Increment count associated with class ID.
inc_class(C,classes(K,C1,V),classes(K,C2,V)) :- inc_nth(C,C1,C2).

% inc_nth(+N, +List1, -List2): List2 is List1 with its Nth element
% incremented. The cut keeps the bound-index case deterministic.
inc_nth(1,[X|T],[Y|T]) :- !, succ(X,Y).
inc_nth(N,[X|T1],[X|T2]) :- succ(M,N), inc_nth(M,T1,T2).


%% remove_class( +ID:class_id, +C1:classes(A), -C2:classes(A)) is det.
%
% Removes class ID: both its count and its value are dropped and the
% number of classes is decremented.
remove_class(I,classes(K1,C1,V1),classes(K2,C2,V2)) :-
    remove_from_list(I,_,C1,C2),
    remove_from_list(I,_,V1,V2),
    succ(K2,K1).

%% add_class( +X:A, -ID:class_id, +C1:classes(A), -C2:classes(A)) is det.
%
% Add a class associated with value X, with an initial count of 1.
% ID is the id of the new class: the previous number of classes plus one.
add_class(X,K2,classes(K1,C1,V1),classes(K2,C2,V2)) :-
    succ(K1,K2),
    append(C1,[1],C2),
    append(V1,[X],V2).


% remove_from_list(?N, ?X, +List, -Rest)
%
% X is the Nth (1-based) element of List and Rest is List without it.
% With N unbound, all positions are enumerated in order on backtracking.
% With N bound, it succeeds at most once and leaves no choice point, so
% remove_class/3 is deterministic as documented.
remove_from_list(N, X, [Y|T1], Rest) :-
    (   var(N)
    ->  (   N = 1, X = Y, Rest = T1
        ;   remove_from_list(M, X, T1, T2), succ(M, N), Rest = [Y|T2]
        )
    ;   N == 1
    ->  X = Y, Rest = T1
    ;   succ(M, N), Rest = [Y|T2], remove_from_list(M, X, T1, T2)
    ).


%------------------------------------------------------------------
% Get posterior distribution at node using stick-breaking
% construction.

% lazy_dp(+Alpha, +Base, +P0, -Vals, -Probs)//
%
% Lazily unfolds parallel lists Vals (each drawn by call(Base,V)) and
% stick-breaking weights Probs whose remaining mass starts at P0.
% The unfold state is the pair (RemainingMass, RandState), with the random
% state obtained from spawn//1 (presumably a fresh, independent stream —
% confirm against randpred).
lazy_dp(A,H,P0,Vals,Probs) -->
    spawn(S0), { lazy_unfold(unfold_dp(A,H),Vals,Probs,(P0,S0),_) }.

% As lazy_dp//5 but producing a single lazy list of Value:Weight pairs.
lazy_dp_paired(A,H,P0,ValsProbs) -->
    spawn(S0), { lazy_unfold(unfold_dp(A,H),ValsProbs,(P0,S0),_) }.

% One unfold step: draw a value from the base distribution H, then one
% stick-breaking weight. Presumably \> runs on the random-state component
% and \< on the remaining-mass component of the state pair (dcgu) — confirm.
unfold_dp(A,H,V,X) --> \> call(H,V), unfold_gem(A,X).
unfold_dp(A,H,V:X) --> \> call(H,V), unfold_gem(A,X).

% lazy_gem(A,Probs) --> spawn(S0), { lazy_unfold(unfold_gem(A),Probs,(1,S0),_) }.

% One stick-breaking step: draw P via beta(1,A,P), emit weight X = P*P0
% and leave P0-X as the remaining mass.
unfold_gem(A,X) -->
    \> beta(1,A,P),
    \< trans(P0,P1),
    { X is P*P0, P1 is P0-X }.

%% classes_update( +Action:action(A), +C1:classes(A), -C2:classes(A)) is det.
%
% Update classes structure with a new observation: old(_,ID) increments
% the count of class ID, new(X,ID) adds a class for X and binds ID to it.
classes_update(old(_,ID),C1,C2) :- inc_class(ID,C1,C2).
classes_update(new(X,ID),C1,C2) :- add_class(X,ID,C1,C2).


% PARAMETER SAMPLING



% ---------------------------------------------------------------
% Initialisers
% Samplers written in C.

%% dp_sampler_teh( +Prior:gamma_prior, +CountsX:list(list(natural)), -S:param_sampler) is det.
%
% Prepares a predicate for sampling the concentration parameter of a Dirichlet process.
% CountsX holds one list of class counts per CRP whose data inform the parameter.
% The result is the closure crp:sample_dp_teh(A+K, B, NX), where K is the total
% number of classes over all count lists and NX holds the sum of each count list.
% The sampler's =|gem_model|= arguments must be of the form =|dp(_)|=.
dp_sampler_teh( gamma(A,B), CX, crp:sample_dp_teh(ApSumKX,B,NX)) :-
    maplist(sumlist, CX, NX),    % total count per CRP
    maplist(length, CX, KX),     % number of classes per CRP
    sumlist(KX, SumKX),
    ApSumKX is A+SumKX.          % gamma shape shifted by total number of classes

%% py_sampler_teh( +ThPrior:gamma_prior, +DiscPr:beta_prior, +CountsX:list(list(natural)), -S:param_sampler) is det.
%
% Prepares a predicate for sampling the concentration and discount
% parameters of a Pitman-Yor process. CountsX holds one list of class
% counts per CRP. The result is the closure
% crp:sample_py_teh(ThPrior, DiscPr, CountsX).
% The sampler's =|gem_model|= arguments must be of the form =|py(_,_)|=.
py_sampler_teh( ThPrior, DiscPrior, CountsX, crp:Sampler) :-
    Sampler = sample_py_teh( ThPrior, DiscPrior, CountsX).

/* Reference (pure Prolog) implementation of the Pitman-Yor parameter
   sampler, kept for documentation; not compiled.

slow_sample_py_teh( gamma(A,B), beta(DA,DB), CountsX, py(Theta1,Disc1), py(Theta2,Disc2)) -->
    % do several lots of sampling auxiliary variables, one per client node
    % seqmap( py_sample_s_z_w(Theta1,Disc1), CountsX, SX, NSX, ZX, WX),
    seqmap( py_sample_s_z_log_w(Theta1,Disc1), CountsX, SX, NSX, ZX, LogWX),
    {   % maplist(log,WX,LogWX),
        sumlist(SX,SumSX),
        sumlist(NSX,SumNSX),
        sumlist(ZX,SumZX),
        sumlist(LogWX,SumLogWX),
        A1 is A+SumSX, B1 is B-SumLogWX,
        DA1 is DA+SumNSX, DB1 is DB+SumZX
    },
    gamma(A1, B1, Theta2),
    beta(DA1, DB1, Disc2).

py_sample_s_z_w(Theta,Disc,Counts,S,NS,Z,W) -->
    py_sample_bern_z(Disc,Counts,Z),
    py_sample_bern_s(Theta,Disc,Counts,S,NS),
    py_sample_beta_w(Theta,Counts,W).

py_sample_s_z_log_w(Theta,Disc,Counts,S,NS,Z,LogW) -->
    py_sample_bern_z(Disc,Counts,Z),
    py_sample_bern_s(Theta,Disc,Counts,S,NS),
    py_sample_beta_log_w(Theta,Counts,LogW).

py_sample_beta_w(_, [], 1) --> !.
py_sample_beta_w(Theta, Counts, W) -->
    {sumlist(Counts,N), Th1 is Theta+1, N1 is N-1},
    beta( Th1, N1, W).

py_sample_beta_log_w(_, [], 0) --> !.
py_sample_beta_log_w(Theta, Counts, LogW) -->
    {sumlist(Counts,N), Th1 is Theta+1, N1 is N-1},
    beta( Th1, N1, W), { LogW is log(W) }.

py_sample_bern_s(Theta,Disc,Counts,SumS,SumNS) -->
    (   {Counts=[_|Cm1], length(Cm1,Kminus1), numlist(1,Kminus1,KX)}
    ->  {maplist(mul(Disc),KX,KDX)},
        sum_bernoulli(KDX, Theta, SumS),
        {SumNS is Kminus1 - SumS}
    ;   {SumS=0,SumNS=0}
    ).

py_sample_bern_z(Disc,Counts,Z) -->
    {Disc1 is 1-Disc},
    seqmap( sample_bern_z(Disc1), Counts, ZX),
    {sumlist(ZX,Z)}.

sample_bern_z(Disc1,Count,SumZ) -->
    {CountM2 is Count-2},
    (   {CountM2<0} -> {SumZ=0}
    ;   {numlist(0,CountM2,I)},
        sum_bernoulli(I, Disc1, SumZ)
    ).

sum_bernoulli(AX,B,T,S1,S2) :- sum_bernoulli(AX,B,0,T,S1,S2).
sum_bernoulli([],_,T,T,S,S) :- !.
sum_bernoulli([A|AX],B,T1,T3,S1,S3) :-
    bernoulli(A,B,X,S1,S2), T2 is T1+X,
    sum_bernoulli(AX,B,T2,T3,S2,S3).

% Gamma distribution with rate parameter B.
:- procedure gamma(1,1).
gamma(A,B,X) --> gamma(A,U), {X is U/B}.

% Bernoulli with unnormalised weights for 0 and 1.
:- procedure bernoulli(1,1).
bernoulli(A,B,X) -->
    uniform01(U),
    ({(A+B)*U<B} -> {X=1}; {X=0} ).
*/