Mercurial > hg > plcrp
view crp.c @ 3:974d7be8eec4 tip
Update to pack-based dcg utilities
author | samer |
---|---|
date | Tue, 03 Oct 2017 11:52:23 +0100 |
parents | b31415b4a196 |
children |
line wrap: on
line source
/*
 * crp.c -- SWI-Prolog foreign predicates for CRP sampling with
 * Dirichlet process (dp/1) and Pitman-Yor (py/2) pseudo-count models.
 *
 * Each foreign predicate is a C version of the Prolog reference
 * implementation quoted in the comment above it.
 */
#include <SWI-Prolog.h>
#include <math.h>
#include <float.h>
#include <stdio.h>
#include <stdlib.h>
#include "rndutils.h"
#include "plutils.c"

install_t install(void);

foreign_t crp_prob( term_t alpha, term_t classes, term_t x, term_t pprob, term_t p);
foreign_t crp_sample( term_t alpha, term_t classes, term_t action, term_t rnd1, term_t rnd2);
foreign_t crp_sample_obs( term_t alpha, term_t classes, term_t x, term_t probx, term_t act, term_t rnd1, term_t rnd2);
foreign_t crp_sample_rm( term_t classes, term_t x, term_t class, term_t rnd1, term_t rnd2);
foreign_t sample_dp_teh( term_t ApSumKX, term_t B, term_t NX, term_t p1, term_t p2, term_t rnd1, term_t rnd2);
foreign_t sample_py_teh( term_t ThPrior, term_t DPrior, term_t CountsX, term_t p1, term_t p2, term_t rnd1, term_t rnd2);

/* Atoms and functors created once in install() and reused by every call. */
static atom_t atom_new;
static functor_t functor_old1, functor_old2;
static functor_t functor_dp1, functor_py2;

/*
 * Module entry point: registers the six foreign predicates and creates
 * the atoms/functors (new, old/1, old/2, dp/1, py/2) used to build and
 * recognise terms.
 */
install_t install(void)
{
  PL_register_foreign("crp_prob", 5, (void *)crp_prob, 0);
  PL_register_foreign("crp_sample", 5, (void *)crp_sample, 0);
  PL_register_foreign("crp_sample_obs", 7, (void *)crp_sample_obs, 0);
  PL_register_foreign("crp_sample_rm", 5, (void *)crp_sample_rm, 0);
  PL_register_foreign("sample_dp_teh", 7, (void *)sample_dp_teh, 0);
  PL_register_foreign("sample_py_teh", 7, (void *)sample_py_teh, 0);

  functor_dp1 = PL_new_functor(PL_new_atom("dp"),1);
  functor_py2 = PL_new_functor(PL_new_atom("py"),2);
  functor_old1 = PL_new_functor(PL_new_atom("old"),1);
  functor_old2 = PL_new_functor(PL_new_atom("old"),2);
  atom_new = PL_new_atom("new");
}

// unify Prolog BLOB with RndState structure
static int unify_state(term_t state, RndState *S, PL_blob_t *pblob)
{
  return PL_unify_blob(state, S, sizeof(RndState), pblob);
}

// extract RndState structure from Prolog BLOB
//
// Copies the blob payload into *S0 and stores the blob type in *ppblob.
// Returns FALSE (instead of dereferencing garbage) when the term is not
// a blob or its payload is too small to hold an RndState.
static int get_state(term_t state, RndState *S0, PL_blob_t **ppblob)
{
  size_t len = 0;
  RndState *S = NULL;

  if (!PL_get_blob(state, (void **)&S, &len, ppblob))
    return FALSE;
  if (S == NULL || len < sizeof(RndState))
    return FALSE;

  *S0 = *S;
  return TRUE;
}

// -----------------------------------------------------

// Prolog versions of functions to implement

int counts_dist( term_t gem, term_t counts, size_t len, double *dist);
int get_classes(term_t Classes, term_t Counts, term_t Vals, long *len);
void stoch(double *x, size_t len);

/*
%% crp_prob( +GEM:gem_model, +Classes:classes(A), +X:A, +PProb:float, -Prob:float) is det.
%
%  Compute the probability Prob of observing X given a CRP
%  and a base probability of PProb.

crp_prob( Alpha, classes(Counts,Vals), X, PProb, P) :-
   counts_dist( Alpha, Counts, Counts1),
   stoch( Counts1, Probs, _),
   maplist( equal(X), Vals, Mask),
   maplist( mul, [PProb | Mask], Probs, PostProbs),
   sumlist( PostProbs, P).
*/

/*
 * dist[0] is the weight of a new class and dist[1..len] the weights of
 * the existing classes.  After normalisation the result is
 * PProb*dist[0] plus dist[i] for every class value that unifies with X.
 * Note that the match is done with PL_unify(), as in the original.
 */
foreign_t crp_prob( term_t Alpha, term_t Classes, term_t X, term_t PProb, term_t Prob)
{
  term_t Counts=PL_new_term_ref();
  term_t Vals=PL_new_term_ref();
  double prob=0, pprob;
  double *dist=NULL;
  long len=0;

  int rc = get_double(PProb, &pprob)
        && get_classes(Classes, Counts, Vals, &len)
        && alloc_array(len+1, sizeof(double), (void **)&dist)
        && counts_dist(Alpha, Counts, len, dist);

  if (rc) {
    term_t Val = PL_new_term_ref();
    long i;

    stoch(dist,len+1);
    prob = pprob*dist[0];
    for (i=1; i<=len && PL_get_list(Vals,Val,Vals); i++) {
      if (PL_unify(Val,X)) prob += dist[i];
    }
  }

  free(dist);
  return rc && PL_unify_float(Prob,prob);
}

/*
%% crp_sample( +GEM:gem_model, +Classes:classes(A), -A:action(A))// is det.
%
%  Sample a new value from CRP, Action A is either new, which means
%  that the user should sample a new value from the base distribution,
%  or old(X,C), where X is an old value and C is the index of its class.
%  Operates in random state DCG.

crp_sample( Alpha, classes(Counts,Vals), Action, RS1, RS2) :-
   counts_dist(Alpha, Counts, Counts1),
   discrete(Counts1,Z,RS1,RS2),
   (  Z>1 -> succ(C,Z), nth1(C,Vals,X), Action=old(X,C)
   ;  Action=new).
*/

/*
 * Draws an index z from the unnormalised weights.  z==0 means "new";
 * otherwise the z-th element of Vals is fetched and old(X,z) is
 * returned.  The predicate fails if Vals has fewer than z elements.
 */
foreign_t crp_sample( term_t Alpha, term_t Classes, term_t Action, term_t Rnd1, term_t Rnd2)
{
  term_t Counts=PL_new_term_ref();
  term_t Vals=PL_new_term_ref();
  PL_blob_t *blob;
  double *dist=NULL;
  RndState rs;
  long len=0;

  int rc = get_classes(Classes, Counts, Vals, &len)
        && alloc_array(len+1, sizeof(double), (void **)&dist)
        && counts_dist(Alpha, Counts, len, dist)
        && get_state(Rnd1,&rs,&blob);

  if (rc) {
    int z=Discrete( &rs, len+1, dist, sum_array(dist,len+1));

    if (z==0) {
      rc = PL_unify_atom(Action,atom_new);
    } else {
      term_t X=PL_new_term_ref();
      int i=0;

      /* walk Vals until X holds the z-th value */
      while (i<z && PL_get_list(Vals,X,Vals)) i++;
      rc = (i==z) && PL_unify_term(Action, PL_FUNCTOR, functor_old2, PL_TERM, X, PL_INT, z);
    }
  }

  free(dist);
  return rc && unify_state(Rnd2,&rs,blob);
}

/*
%% crp_sample_obs( +GEM:gem_model, +Classes:classes(A), +X:A, +PProb:float, -A:action)// is det.
%
%  Sample class appropriate for observation of value X. PProb is the
%  base probability of X from the base distribution. Action A is new
%  or old(Class).
%  Operates in random state DCG.

crp_sample_obs( Alpha, classes(Counts,Vals), X, ProbX, A, RS1, RS2) :-
   counts_dist( Alpha, Counts, [CNew|Counts1]),
   PNew is CNew*ProbX,
   maplist( post_count(X),Vals,Counts1,Counts2),
   discrete( [PNew|Counts2], Z, RS1, RS2),
   (Z=1 -> A=new; succ(C,Z), A=old(C)).
*/

/*
 * The new-class weight is scaled by ProbX and the weight of every class
 * whose value does not unify with X is zeroed before sampling.
 */
foreign_t crp_sample_obs( term_t Alpha, term_t Classes, term_t X, term_t Probx, term_t Act, term_t Rnd1, term_t Rnd2)
{
  term_t Counts=PL_new_term_ref();
  term_t Vals=PL_new_term_ref();
  PL_blob_t *blob;
  double probx=0;
  double *dist=NULL;
  long len=0;
  RndState rs;

  int rc = get_double(Probx,&probx)
        && get_classes(Classes, Counts, Vals, &len)
        && alloc_array(len+1, sizeof(double), (void **)&dist)
        && counts_dist(Alpha, Counts, len, dist)
        && get_state(Rnd1,&rs,&blob);

  if (rc) {
    term_t Val=PL_new_term_ref();
    long i;
    int z;

    dist[0] *= probx;
    for (i=1; i<=len && PL_get_list(Vals,Val,Vals); i++) {
      if (!PL_unify(Val,X)) dist[i]=0;
    }

    z=Discrete( &rs, len+1, dist, sum_array(dist,len+1));
    if (z==0) {
      rc = PL_unify_atom(Act,atom_new);
    } else {
      rc = PL_unify_term(Act, PL_FUNCTOR, functor_old1, PL_INT, z);
    }
  }

  free(dist);
  return rc && unify_state(Rnd2,&rs,blob);
}

/*
%% crp_sample_rm( +Classes:classes(A), +X:A, -C:natural)// is det.
%
%  Sample appropriate class from which to remove value X.
%  Operates in random state DCG.

crp_sample_rm( classes(Counts,Vals), X, Class, RS1, RS2) :-
   maplist(post_count(X),Vals,Counts,Counts1),
   discrete( Counts1, Class, RS1, RS2).
*/

/*
 * Samples among the existing classes only (no "new" slot), with the
 * counts of classes whose value does not unify with X zeroed.  The
 * result is the 1-based class index; fails if Discrete() returns an
 * index outside 0..len-1.
 */
foreign_t crp_sample_rm( term_t Classes, term_t X, term_t Class, term_t Rnd1, term_t Rnd2)
{
  term_t Counts=PL_new_term_ref();
  term_t Vals=PL_new_term_ref();
  PL_blob_t *blob;
  double *dist=NULL;
  long len=0;
  RndState rs;

  int rc = get_classes(Classes, Counts, Vals, &len)
        && alloc_array(len, sizeof(double), (void **)&dist)
        && get_list_doubles(Counts, dist, len)
        && get_state(Rnd1,&rs,&blob);

  if (rc) {
    term_t Val=PL_new_term_ref();
    long i;
    int z;

    for (i=0; i<len && PL_get_list(Vals,Val,Vals); i++) {
      if (!PL_unify(Val,X)) dist[i]=0;
    }

    z = Discrete( &rs, len, dist, sum_array(dist,len));
    rc = (z<len) && PL_unify_integer(Class, z+1);
  }

  free(dist);
  return rc && unify_state(Rnd2,&rs,blob);
}

/*
post_count(X,Val,Count,PC) :- X=Val -> PC=Count; PC=0.

% -----------------------------------------------------------
% Dirichlet process and Pitman-Yor process
% pseudo-counts models.

counts_dist(_,[],0,[1]) :- !.
counts_dist(dp(Alpha),Counts,_,[Alpha|Counts]) :- !.
counts_dist(py(Alpha,Discount),Counts,K,[CNew|Counts1]) :- !,
   CNew is Alpha+Discount*K,
   maplist(sub(Discount),Counts,Counts1).
*/

/* Fetch the n-th argument of Term as a double; FALSE if absent or not numeric. */
int get_float_arg(int n,term_t Term, double *px)
{
  term_t X=PL_new_term_ref();
  return PL_get_arg(n,Term,X) && PL_get_float(X,px);
}

/*
 * Fill dist[0..len] with unnormalised class weights.
 *
 *   gem     dp(Alpha) or py(Theta,Discount)
 *   counts  Prolog list of at least len numeric counts
 *   len     number of existing classes
 *   dist    caller-provided array of len+1 doubles
 *
 * dist[0] is the weight of a new class.  With no classes (len==0) the
 * only weight is dist[0]=1, whatever gem is.  Returns FALSE if gem is
 * neither dp/1 nor py/2, if a parameter or count is not numeric, or if
 * the list holds fewer than len counts (so dist[] is never left partly
 * uninitialised on success).
 */
int counts_dist( term_t gem, term_t counts, size_t len, double *dist)
{
  term_t head;
  size_t i;

  if (len==0) {
    dist[0]=1;
    return TRUE;
  }

  if (PL_is_functor(gem, functor_dp1)) {
    double alpha;

    head=PL_new_term_ref();
    if (!get_float_arg(1,gem,&alpha)) return FALSE;

    dist[0] = alpha;
    for (i=1; i<=len; i++) {
      if (!PL_get_list(counts,head,counts)) return FALSE;
      if (!PL_get_float(head,&dist[i])) return FALSE;
    }
    return TRUE;
  }

  if (PL_is_functor(gem, functor_py2)) {
    double theta, disc, c;

    head=PL_new_term_ref();
    if (!(get_float_arg(1,gem,&theta) && get_float_arg(2,gem,&disc))) return FALSE;

    dist[0] = theta + disc*len;
    for (i=1; i<=len; i++) {
      if (!PL_get_list(counts,head,counts)) return FALSE;
      if (!PL_get_float(head,&c)) return FALSE;
      dist[i] = c-disc;
    }
    return TRUE;
  }

  return FALSE;
}

/*
 * Take apart a classes term: argument 1 is the class count K (stored in
 * *len), argument 2 the counts list and argument 3 the values list.
 * A negative K is rejected, because callers size their buffers from it.
 */
int get_classes(term_t Classes, term_t Counts, term_t Vals, long *len)
{
  term_t K=PL_new_term_ref();

  return PL_get_arg(1,Classes,K)
      && PL_get_arg(2,Classes,Counts)
      && PL_get_arg(3,Classes,Vals)
      && PL_get_long(K,len)
      && *len >= 0;
}

/*
 * Normalise x[0..len-1] in place by dividing each element by the sum.
 * No special handling of a zero sum (same as the original).
 */
void stoch(double *x, size_t len)
{
  size_t i;
  double total=0;

  for (i=0; i<len; i++) total += x[i];
  for (i=0; i<len; i++) x[i] /= total;
}

/*
sample_dp_teh( ApSumKX, B, NX, dp(Alpha1), dp(Alpha2)) -->
   { Alpha1_1 is Alpha1+1 },
   seqmap(beta(Alpha1_1),NX,WX),
   seqmap(bernoulli(Alpha1),NX,SX),
   {  maplist(log,WX,LogWX),
      sumlist(SX,SumSX),
      sumlist(LogWX,SumLogWX),
      A1 is ApSumKX-SumSX, B1 is B-SumLogWX },
   gamma(A1,B1,Alpha2).

% run_left( seqmap(accum_log_beta(Alpha1_1),NX), 0, SumLogWX),
% run_left( seqmap(accum_bernoulli(Alpha1),NX), 0, SumSX),
%accum_log_beta(A,B) --> \> beta(A,B,X), { LogX is log(X) }, \< add(LogX).
%accum_bernoulli(A,B) --> \> bernoulli(A,B,X), \< add(X).
*/

/* Returns 1 when (a+b)*Uniform(rs) < b, otherwise 0. */
int Bernoulli(RndState *rs,double a,double b)
{
  return ((a+b)*Uniform(rs)<b) ? 1 : 0;
}

/*
 * Resample the dp/1 parameter from p1 into p2, following the Prolog
 * reference above.  All arguments are validated before any random draw
 * so the sampler never runs on an uninitialised RndState; the draw
 * order on success is unchanged (Beta then Bernoulli per element of NX,
 * then one Gamma).
 */
foreign_t sample_dp_teh( term_t ApSumKX, term_t B, term_t NX, term_t p1, term_t p2, term_t rnd1, term_t rnd2)
{
  term_t N=PL_new_term_ref();
  PL_blob_t *pblob;
  double apsumkx, b, alphap1;
  double alpha1=0, alpha2;
  double sum_log_wx=0, sum_sx=0;
  long n=0;
  RndState rs;

  if (!( get_double(ApSumKX, &apsumkx)
      && get_double(B, &b)
      && get_float_arg(1,p1,&alpha1)
      && get_state(rnd1,&rs,&pblob)))
    return FALSE;

  alphap1 = alpha1+1;
  while (PL_get_list(NX,N,NX)) {
    if (!get_long(N,&n)) return FALSE;   /* do not sample with a stale n */
    sum_log_wx += log(Beta(&rs,alphap1,n));
    sum_sx += Bernoulli(&rs,alpha1,n);
  }

  alpha2 = Gamma(&rs, apsumkx-sum_sx)/(b-sum_log_wx);

  return PL_unify_term(p2, PL_FUNCTOR, functor_dp1, PL_FLOAT, alpha2)
      && unify_state(rnd2,&rs,pblob);
}

/*
 * Resample the py/2 parameters (theta, discount) from p1 into p2.
 *
 *   ThPrior  term whose args 1,2 are read as theta_a, theta_b
 *   DPrior   term whose args 1,2 are read as disc_a, disc_b
 *   CountsX  list of lists of integer counts
 *
 * For each inner list: every count after the first draws one Bernoulli
 * that increments sum_sx or sum_nsx; every count c draws c-1 Bernoullis
 * accumulated in sum_zx; and if the counts total n>1 one Beta draw is
 * logged into sum_log_wx.  As in sample_dp_teh, arguments are validated
 * before any random draw and the draw order on success is unchanged.
 */
foreign_t sample_py_teh( term_t ThPrior, term_t DPrior, term_t CountsX, term_t p1, term_t p2, term_t rnd1, term_t rnd2)
{
  PL_blob_t *pblob;
  term_t Counts = PL_new_term_ref();
  term_t Count = PL_new_term_ref();
  double theta_a, disc_a;
  double theta_b, disc_b;
  double theta1, disc1;
  double theta2, disc2;
  double theta1_1, disc1_1;
  double sum_log_wx=0, sum_sx=0, sum_nsx=0, sum_zx=0;
  RndState rs;

  if (!( get_float_arg(1,ThPrior,&theta_a)
      && get_float_arg(2,ThPrior,&theta_b)
      && get_float_arg(1,DPrior,&disc_a)
      && get_float_arg(2,DPrior,&disc_b)
      && get_float_arg(1,p1,&theta1)
      && get_float_arg(2,p1,&disc1)
      && get_state(rnd1,&rs,&pblob)))
    return FALSE;

  theta1_1 = theta1+1;
  disc1_1 = 1-disc1;

  while (PL_get_list(CountsX,Counts,CountsX)) {
    long n=0, i, c=0;
    int k;

    for (k=0; PL_get_list(Counts,Count,Counts); k++, n+=c) {
      if (!get_long(Count,&c)) return FALSE;   /* do not sample with a stale c */
      if (k>0) {
        if (Bernoulli(&rs, disc1*k, theta1)) sum_sx++;
        else sum_nsx++;
      }
      for (i=0; i<c-1; i++) sum_zx += Bernoulli(&rs, i, disc1_1);
    }
    if (n>1) sum_log_wx += log(Beta(&rs, theta1_1, n-1));
  }

  theta2 = Gamma(&rs, theta_a + sum_sx)/(theta_b-sum_log_wx);
  disc2 = Beta(&rs, disc_a + sum_nsx, disc_b + sum_zx);

  return unify_state(rnd2,&rs,pblob)
      && PL_unify_term(p2, PL_FUNCTOR, functor_py2, PL_FLOAT, theta2, PL_FLOAT, disc2);
}

/*
foreign_t sum_lengths( term_t Lists, term_t Total)
{
  double total=0;
  size_t len=0;
  term_t List=PL_new_term_ref();
  term_t Tail=PL_new_term_ref();
  int rc=1;

  while (rc && PL_get_list(Lists,List,Lists)) {
    rc = PL_skip_list(List,Tail,&len);
    total += len;
  }
  return rc && PL_unify_integer(Total, total);
}
*/