samer@0: /* samer@0: */ samer@0: samer@0: #include samer@0: #include samer@0: #include samer@0: #include samer@0: samer@0: #include "rndutils.h" samer@0: #include "plutils.c" samer@0: samer@0: install_t install(); samer@0: samer@0: foreign_t crp_prob( term_t alpha, term_t classes, term_t x, term_t pprob, term_t p); samer@0: foreign_t crp_sample( term_t alpha, term_t classes, term_t action, term_t rnd1, term_t rnd2); samer@0: foreign_t crp_sample_obs( term_t alpha, term_t classes, term_t x, term_t probx, term_t act, term_t rnd1, term_t rnd2); samer@0: foreign_t crp_sample_rm( term_t classes, term_t x, term_t class, term_t rnd1, term_t rnd2); samer@0: foreign_t sample_dp_teh( term_t ApSumKX, term_t B, term_t NX, term_t p1, term_t p2, term_t rnd1, term_t rnd2); samer@0: foreign_t sample_py_teh( term_t ThPrior, term_t DPrior, term_t CountsX, term_t p1, term_t p2, term_t rnd1, term_t rnd2); samer@0: samer@0: static atom_t atom_new; samer@0: static functor_t functor_old1, functor_old2; samer@0: static functor_t functor_dp1, functor_py2; samer@0: samer@0: install_t install() { samer@0: PL_register_foreign("crp_prob", 5, (void *)crp_prob, 0); samer@0: PL_register_foreign("crp_sample", 5, (void *)crp_sample, 0); samer@0: PL_register_foreign("crp_sample_obs", 7, (void *)crp_sample_obs, 0); samer@0: PL_register_foreign("crp_sample_rm", 5, (void *)crp_sample_rm, 0); samer@0: PL_register_foreign("sample_dp_teh", 7, (void *)sample_dp_teh, 0); samer@0: PL_register_foreign("sample_py_teh", 7, (void *)sample_py_teh, 0); samer@0: samer@0: functor_dp1 = PL_new_functor(PL_new_atom("dp"),1); samer@0: functor_py2 = PL_new_functor(PL_new_atom("py"),2); samer@0: functor_old1 = PL_new_functor(PL_new_atom("old"),1); samer@0: functor_old2 = PL_new_functor(PL_new_atom("old"),2); samer@0: atom_new = PL_new_atom("new"); samer@0: } samer@0: samer@0: samer@0: // unify Prolog BLOB with RndState structure samer@0: static int unify_state(term_t state, RndState *S, PL_blob_t *pblob) { samer@0: return PL_unify_blob(state, S, sizeof(RndState), pblob); samer@0: } samer@0: samer@0: // extract RndState structure from Prolog BLOB samer@0: static int get_state(term_t state, RndState *S0, PL_blob_t **ppblob) samer@0: { samer@0: size_t len; samer@0: RndState *S; samer@0: samer@0: PL_get_blob(state, (void **)&S, &len, ppblob); samer@0: *S0=*S; samer@0: return TRUE; samer@0: } samer@0: samer@0: samer@0: // ----------------------------------------------------- samer@0: // Prolog versions of functions to implement samer@0: samer@0: int counts_dist( term_t gem, term_t counts, size_t len, double *dist); samer@0: int get_classes(term_t Classes, term_t Counts, term_t Vals, long *len); samer@0: void stoch(double *x, size_t len); samer@0: samer@0: /* samer@0: %% crp_prob( +GEM:gem_model, +Classes:classes(A), +X:A, +PProb:float, -Prob:float) is det. samer@0: % samer@0: % Compute the probability Prob of observing X given a CRP samer@0: % and a base probability of PProb. samer@0: crp_prob( Alpha, classes(Counts,Vals), X, PProb, P) :- samer@0: counts_dist( Alpha, Counts, Counts1), samer@0: stoch( Counts1, Probs, _), samer@0: maplist( equal(X), Vals, Mask), samer@0: maplist( mul, [PProb | Mask], Probs, PostProbs), samer@0: sumlist( PostProbs, P). samer@0: samer@0: */ samer@0: samer@0: foreign_t crp_prob( term_t Alpha, term_t Classes, term_t X, term_t PProb, term_t Prob) samer@0: { samer@0: term_t Counts=PL_new_term_ref(); samer@0: term_t Vals=PL_new_term_ref(); samer@0: double prob=0, pprob; samer@0: double *dist=NULL; samer@0: long len=0; samer@0: samer@0: int rc = get_double(PProb, &pprob) samer@0: && get_classes(Classes, Counts, Vals, &len) samer@0: && alloc_array(len+1, sizeof(double), (void **)&dist) samer@0: && counts_dist(Alpha, Counts, len, dist); samer@0: samer@0: if (rc) { samer@0: term_t Val = PL_new_term_ref(); samer@0: int i; samer@0: samer@0: stoch(dist,len+1); samer@0: prob = pprob*dist[0]; samer@0: for (i=1; i<=len && PL_get_list(Vals,Val,Vals); i++) { samer@0: if (PL_unify(Val,X)) prob += dist[i]; samer@0: } samer@0: } else rc=0; samer@0: if (dist) free(dist); samer@0: return rc && PL_unify_float(Prob,prob); samer@0: } samer@0: samer@0: /* samer@0: samer@0: samer@0: %% crp_sample( +GEM:gem_model, +Classes:classes(A), -A:action(A))// is det. samer@0: % samer@0: % Sample a new value from CRP, Action A is either new, which means samer@0: % that the user should sample a new value from the base distribtion, samer@0: % or old(X,C), where X is an old value and C is the index of its class. samer@0: % Operates in random state DCG. samer@0: crp_sample( Alpha, classes(Counts,Vals), Action, RS1, RS2) :- samer@0: counts_dist(Alpha, Counts, Counts1), samer@0: discrete(Counts1,Z,RS1,RS2), samer@0: ( Z>1 -> succ(C,Z), nth1(C,Vals,X), Action=old(X,C) samer@0: ; Action=new). samer@0: samer@0: */ samer@0: samer@0: foreign_t crp_sample( term_t Alpha, term_t Classes, term_t Action, term_t Rnd1, term_t Rnd2) samer@0: { samer@0: term_t Counts=PL_new_term_ref(); samer@0: term_t Vals=PL_new_term_ref(); samer@0: PL_blob_t *blob; samer@0: double *dist=NULL; samer@0: RndState rs; samer@0: long len=0; samer@0: samer@0: int rc = get_classes(Classes, Counts, Vals, &len) samer@0: && alloc_array(len+1, sizeof(double), (void **)&dist) samer@0: && counts_dist(Alpha, Counts, len, dist) samer@0: && get_state(Rnd1,&rs,&blob); samer@0: samer@0: if (rc) { samer@0: int z=Discrete( &rs, len+1, dist, sum_array(dist,len+1)); samer@0: samer@0: if (z==0) { rc = PL_unify_atom(Action,atom_new); } samer@0: else { samer@0: term_t X=PL_new_term_ref(); samer@0: int i=0; samer@0: while (i A=new; succ(C,Z), A=old(C)). samer@0: samer@0: */ samer@0: samer@0: foreign_t crp_sample_obs( term_t Alpha, term_t Classes, term_t X, term_t Probx, term_t Act, term_t Rnd1, term_t Rnd2) samer@0: { samer@0: term_t Counts=PL_new_term_ref(); samer@0: term_t Vals=PL_new_term_ref(); samer@0: PL_blob_t *blob; samer@0: double probx=0; samer@0: double *dist=NULL; samer@0: long len=0; samer@0: RndState rs; samer@0: samer@0: int rc = get_double(Probx,&probx) samer@0: && get_classes(Classes, Counts, Vals, &len) samer@0: && alloc_array(len+1, sizeof(double), (void **)&dist) samer@0: && counts_dist(Alpha, Counts, len, dist) samer@0: && get_state(Rnd1,&rs,&blob); samer@0: samer@0: if (rc) { samer@0: term_t Val=PL_new_term_ref(); samer@0: int i, z; samer@0: samer@0: dist[0] *= probx; samer@0: for (i=1; i<=len && PL_get_list(Vals,Val,Vals); i++) { samer@0: if (!PL_unify(Val,X)) dist[i]=0; samer@0: } samer@0: samer@0: z=Discrete( &rs, len+1, dist, sum_array(dist,len+1)); samer@0: if (z==0) { rc = PL_unify_atom(Act,atom_new); } samer@0: else { samer@0: rc = PL_unify_term(Act, PL_FUNCTOR, functor_old1, PL_INT, z); samer@0: } samer@0: } samer@0: if (dist) free(dist); samer@0: return rc && unify_state(Rnd2,&rs,blob); samer@0: } samer@0: samer@0: /* samer@0: %% crp_sample_rm( +Classes:classes(A), +X:A, -C:natural)// is det. samer@0: % samer@0: % Sample appropriate class from which to remove value X. samer@0: % Operates in random state DCG. samer@0: crp_sample_rm( classes(Counts,Vals), X, Class, RS1, RS2) :- samer@0: maplist(post_count(X),Vals,Counts,Counts1), samer@0: discrete( Counts1, Class, RS1, RS2). samer@0: samer@0: */ samer@0: samer@0: foreign_t crp_sample_rm( term_t Classes, term_t X, term_t Class, term_t Rnd1, term_t Rnd2) samer@0: { samer@0: term_t Counts=PL_new_term_ref(); samer@0: term_t Vals=PL_new_term_ref(); samer@0: PL_blob_t *blob; samer@0: double *dist=NULL; samer@0: long len=0; samer@0: RndState rs; samer@0: samer@0: int rc = get_classes(Classes, Counts, Vals, &len) samer@0: && alloc_array(len, sizeof(double), (void **)&dist) samer@0: && get_list_doubles(Counts, dist, len) samer@0: && get_state(Rnd1,&rs,&blob); samer@0: samer@0: if (rc) { samer@0: term_t Val=PL_new_term_ref(); samer@0: int i, z; samer@0: samer@0: for (i=0; i PC=Count; PC=0. samer@0: samer@0: % ----------------------------------------------------------- samer@0: % Dirichlet process and Pitman-Yor process samer@0: % pseudo-counts models. samer@0: samer@0: counts_dist(_,[],0,[1]) :- !. samer@0: counts_dist(dp(Alpha),Counts,_,[Alpha|Counts]) :- !. samer@0: counts_dist(py(Alpha,Discount),Counts,K,[CNew|Counts1]) :- !, samer@0: CNew is Alpha+Discount*K, samer@0: maplist(sub(Discount),Counts,Counts1). samer@0: samer@0: */ samer@0: samer@0: int get_float_arg(int n,term_t Term, double *px) samer@0: { samer@0: term_t X=PL_new_term_ref(); samer@0: return PL_get_arg(n,Term,X) && PL_get_float(X,px); samer@0: } samer@0: samer@0: int counts_dist( term_t gem, term_t counts, size_t len, double *dist) samer@0: { samer@0: if (len==0) { dist[0]=1; return TRUE; } samer@0: else { samer@0: if (PL_is_functor(gem, functor_dp1)) { samer@0: double alpha; samer@0: term_t head=PL_new_term_ref(); samer@0: int i, rc = get_float_arg(1,gem,&alpha); samer@0: samer@0: dist[0] = alpha; samer@0: for(i=1; rc && i<=len && PL_get_list(counts,head,counts); i++) { samer@0: rc = rc && PL_get_float(head,&dist[i]); samer@0: } samer@0: return rc; samer@0: } else if (PL_is_functor(gem, functor_py2)) { samer@0: double theta, disc, c; samer@0: term_t head=PL_new_term_ref(); samer@0: samer@0: int i, rc = get_float_arg(1,gem,&theta) samer@0: && get_float_arg(2,gem,&disc); samer@0: samer@0: dist[0] = theta + disc*len; samer@0: for(i=1; rc && i<=len && PL_get_list(counts,head,counts); i++) { samer@0: rc = rc && PL_get_float(head,&c); samer@0: dist[i] = c-disc; samer@0: } samer@0: return rc; samer@0: } else return FALSE; samer@0: } samer@0: } samer@0: samer@0: int get_classes(term_t Classes, term_t Counts, term_t Vals, long *len) samer@0: { samer@0: term_t K=PL_new_term_ref(); samer@0: samer@0: return PL_get_arg(1,Classes,K) samer@0: && PL_get_arg(2,Classes,Counts) samer@0: && PL_get_arg(3,Classes,Vals) samer@0: && PL_get_long(K,len); samer@0: } samer@0: samer@0: samer@0: void stoch(double *x, size_t len) samer@0: { samer@0: int i; samer@0: double total=0, *xp; samer@0: for (i=0, xp=x; i samer@0: { Alpha1_1 is Alpha1+1 }, samer@0: seqmap(beta(Alpha1_1),NX,WX), samer@0: seqmap(bernoulli(Alpha1),NX,SX), samer@0: { maplist(log,WX,LogWX), samer@0: sumlist(SX,SumSX), samer@0: sumlist(LogWX,SumLogWX), samer@0: A1 is ApSumKX-SumSX, B1 is B-SumLogWX samer@0: }, samer@0: gamma(A1,B1,Alpha2). samer@0: samer@0: % run_left( seqmap(accum_log_beta(Alpha1_1),NX), 0, SumLogWX), samer@0: % run_left( seqmap(accum_bernoulli(Alpha1),NX), 0, SumSX), samer@0: %accum_log_beta(A,B) --> \> beta(A,B,X), { LogX is log(X) }, \< add(LogX). samer@0: %accum_bernoulli(A,B) --> \> bernoulli(A,B,X), \< add(X). samer@0: samer@0: */ samer@0: samer@0: int Bernoulli(RndState *rs,double a,double b) { samer@0: if ((a+b)*Uniform(rs)0) { if (Bernoulli(&rs, disc1*k, theta1)) sum_sx++; else sum_nsx++; } samer@0: for (i=0; i1) sum_log_wx += log(Beta(&rs, theta1_1, n-1)); samer@0: } samer@0: samer@0: theta2 = Gamma(&rs, theta_a + sum_sx)/(theta_b-sum_log_wx); samer@0: disc2 = Beta(&rs, disc_a + sum_nsx, disc_b + sum_zx); samer@0: return rc && unify_state(rnd2,&rs,pblob) samer@0: && PL_unify_term(p2, PL_FUNCTOR, functor_py2, PL_FLOAT, theta2, PL_FLOAT, disc2); samer@0: } samer@0: samer@0: /* samer@0: foreign_t sum_lengths( term_t Lists, term_t Total) samer@0: { samer@0: double total=0; samer@0: size_t len=0; samer@0: term_t List=PL_new_term_ref(); samer@0: term_t Tail=PL_new_term_ref(); samer@0: int rc=1; samer@0: samer@0: while (rc && PL_get_list(Lists,List,Lists)) { samer@0: rc = PL_skip_list(List,Tail,&len); samer@0: total += len; samer@0: } samer@0: return rc && PL_unify_integer(Total, total); samer@0: } samer@0: */