annotate crp.c @ 3:974d7be8eec4 tip

Update to pack-based dcg utilities
author samer
date Tue, 03 Oct 2017 11:52:23 +0100
parents b31415b4a196
children
rev   line source
samer@0 1 /*
samer@0 2 */
samer@0 3
samer@0 4 #include <SWI-Prolog.h>
samer@0 5 #include <math.h>
samer@0 6 #include <float.h>
samer@0 7 #include <stdio.h>
samer@0 8
samer@0 9 #include "rndutils.h"
samer@0 10 #include "plutils.c"
samer@0 11
samer@0 12 install_t install();
samer@0 13
samer@0 14 foreign_t crp_prob( term_t alpha, term_t classes, term_t x, term_t pprob, term_t p);
samer@0 15 foreign_t crp_sample( term_t alpha, term_t classes, term_t action, term_t rnd1, term_t rnd2);
samer@0 16 foreign_t crp_sample_obs( term_t alpha, term_t classes, term_t x, term_t probx, term_t act, term_t rnd1, term_t rnd2);
samer@0 17 foreign_t crp_sample_rm( term_t classes, term_t x, term_t class, term_t rnd1, term_t rnd2);
samer@0 18 foreign_t sample_dp_teh( term_t ApSumKX, term_t B, term_t NX, term_t p1, term_t p2, term_t rnd1, term_t rnd2);
samer@0 19 foreign_t sample_py_teh( term_t ThPrior, term_t DPrior, term_t CountsX, term_t p1, term_t p2, term_t rnd1, term_t rnd2);
samer@0 20
samer@0 21 static atom_t atom_new;
samer@0 22 static functor_t functor_old1, functor_old2;
samer@0 23 static functor_t functor_dp1, functor_py2;
samer@0 24
samer@0 25 install_t install() {
samer@0 26 PL_register_foreign("crp_prob", 5, (void *)crp_prob, 0);
samer@0 27 PL_register_foreign("crp_sample", 5, (void *)crp_sample, 0);
samer@0 28 PL_register_foreign("crp_sample_obs", 7, (void *)crp_sample_obs, 0);
samer@0 29 PL_register_foreign("crp_sample_rm", 5, (void *)crp_sample_rm, 0);
samer@0 30 PL_register_foreign("sample_dp_teh", 7, (void *)sample_dp_teh, 0);
samer@0 31 PL_register_foreign("sample_py_teh", 7, (void *)sample_py_teh, 0);
samer@0 32
samer@0 33 functor_dp1 = PL_new_functor(PL_new_atom("dp"),1);
samer@0 34 functor_py2 = PL_new_functor(PL_new_atom("py"),2);
samer@0 35 functor_old1 = PL_new_functor(PL_new_atom("old"),1);
samer@0 36 functor_old2 = PL_new_functor(PL_new_atom("old"),2);
samer@0 37 atom_new = PL_new_atom("new");
samer@0 38 }
samer@0 39
samer@0 40
samer@0 41 // unify Prolog BLOB with RndState structure
samer@0 42 static int unify_state(term_t state, RndState *S, PL_blob_t *pblob) {
samer@0 43 return PL_unify_blob(state, S, sizeof(RndState), pblob);
samer@0 44 }
samer@0 45
samer@0 46 // extract RndState structure from Prolog BLOB
samer@0 47 static int get_state(term_t state, RndState *S0, PL_blob_t **ppblob)
samer@0 48 {
samer@0 49 size_t len;
samer@0 50 RndState *S;
samer@0 51
samer@0 52 PL_get_blob(state, (void **)&S, &len, ppblob);
samer@0 53 *S0=*S;
samer@0 54 return TRUE;
samer@0 55 }
samer@0 56
samer@0 57
samer@0 58 // -----------------------------------------------------
samer@0 59 // Prolog versions of functions to implement
samer@0 60
samer@0 61 int counts_dist( term_t gem, term_t counts, size_t len, double *dist);
samer@0 62 int get_classes(term_t Classes, term_t Counts, term_t Vals, long *len);
samer@0 63 void stoch(double *x, size_t len);
samer@0 64
samer@0 65 /*
samer@0 66 %% crp_prob( +GEM:gem_model, +Classes:classes(A), +X:A, +PProb:float, -Prob:float) is det.
samer@0 67 %
samer@0 68 % Compute the probability Prob of observing X given a CRP
samer@0 69 % and a base probability of PProb.
samer@0 70 crp_prob( Alpha, classes(Counts,Vals), X, PProb, P) :-
samer@0 71 counts_dist( Alpha, Counts, Counts1),
samer@0 72 stoch( Counts1, Probs, _),
samer@0 73 maplist( equal(X), Vals, Mask),
samer@0 74 maplist( mul, [PProb | Mask], Probs, PostProbs),
samer@0 75 sumlist( PostProbs, P).
samer@0 76
samer@0 77 */
samer@0 78
samer@0 79 foreign_t crp_prob( term_t Alpha, term_t Classes, term_t X, term_t PProb, term_t Prob)
samer@0 80 {
samer@0 81 term_t Counts=PL_new_term_ref();
samer@0 82 term_t Vals=PL_new_term_ref();
samer@0 83 double prob=0, pprob;
samer@0 84 double *dist=NULL;
samer@0 85 long len=0;
samer@0 86
samer@0 87 int rc = get_double(PProb, &pprob)
samer@0 88 && get_classes(Classes, Counts, Vals, &len)
samer@0 89 && alloc_array(len+1, sizeof(double), (void **)&dist)
samer@0 90 && counts_dist(Alpha, Counts, len, dist);
samer@0 91
samer@0 92 if (rc) {
samer@0 93 term_t Val = PL_new_term_ref();
samer@0 94 int i;
samer@0 95
samer@0 96 stoch(dist,len+1);
samer@0 97 prob = pprob*dist[0];
samer@0 98 for (i=1; i<=len && PL_get_list(Vals,Val,Vals); i++) {
samer@0 99 if (PL_unify(Val,X)) prob += dist[i];
samer@0 100 }
samer@0 101 } else rc=0;
samer@0 102 if (dist) free(dist);
samer@0 103 return rc && PL_unify_float(Prob,prob);
samer@0 104 }
samer@0 105
samer@0 106 /*
samer@0 107
samer@0 108
samer@0 109 %% crp_sample( +GEM:gem_model, +Classes:classes(A), -A:action(A))// is det.
samer@0 110 %
samer@0 111 % Sample a new value from CRP, Action A is either new, which means
samer@0 112 % that the user should sample a new value from the base distribution,
samer@0 113 % or old(X,C), where X is an old value and C is the index of its class.
samer@0 114 % Operates in random state DCG.
samer@0 115 crp_sample( Alpha, classes(Counts,Vals), Action, RS1, RS2) :-
samer@0 116 counts_dist(Alpha, Counts, Counts1),
samer@0 117 discrete(Counts1,Z,RS1,RS2),
samer@0 118 ( Z>1 -> succ(C,Z), nth1(C,Vals,X), Action=old(X,C)
samer@0 119 ; Action=new).
samer@0 120
samer@0 121 */
samer@0 122
samer@0 123 foreign_t crp_sample( term_t Alpha, term_t Classes, term_t Action, term_t Rnd1, term_t Rnd2)
samer@0 124 {
samer@0 125 term_t Counts=PL_new_term_ref();
samer@0 126 term_t Vals=PL_new_term_ref();
samer@0 127 PL_blob_t *blob;
samer@0 128 double *dist=NULL;
samer@0 129 RndState rs;
samer@0 130 long len=0;
samer@0 131
samer@0 132 int rc = get_classes(Classes, Counts, Vals, &len)
samer@0 133 && alloc_array(len+1, sizeof(double), (void **)&dist)
samer@0 134 && counts_dist(Alpha, Counts, len, dist)
samer@0 135 && get_state(Rnd1,&rs,&blob);
samer@0 136
samer@0 137 if (rc) {
samer@0 138 int z=Discrete( &rs, len+1, dist, sum_array(dist,len+1));
samer@0 139
samer@0 140 if (z==0) { rc = PL_unify_atom(Action,atom_new); }
samer@0 141 else {
samer@0 142 term_t X=PL_new_term_ref();
samer@0 143 int i=0;
samer@0 144 while (i<z && PL_get_list(Vals,X,Vals)) i++;
samer@0 145 rc = (i==z) && PL_unify_term(Action, PL_FUNCTOR, functor_old2, PL_TERM, X, PL_INT, z);
samer@0 146 }
samer@0 147 }
samer@0 148 if (dist) free(dist);
samer@0 149 return rc && unify_state(Rnd2,&rs,blob);
samer@0 150 }
samer@0 151
samer@0 152 /*
samer@0 153
samer@0 154 %% crp_sample_obs( +GEM:gem_model, +Classes:classes(A), +X:A, +PProb:float, -A:action)// is det.
samer@0 155 %
samer@0 156 % Sample class appropriate for observation of value X. PProb is the
samer@0 157 % base probability of X from the base distribution. Action A is new
samer@0 158 % or old(Class).
samer@0 159 % Operates in random state DCG.
samer@0 160 crp_sample_obs( Alpha, classes(Counts,Vals), X, ProbX, A, RS1, RS2) :-
samer@0 161 counts_dist( Alpha, Counts, [CNew|Counts1]),
samer@0 162 PNew is CNew*ProbX,
samer@0 163 maplist( post_count(X),Vals,Counts1,Counts2),
samer@0 164 discrete( [PNew|Counts2], Z, RS1, RS2),
samer@0 165 ( Z=1 -> A=new; succ(C,Z), A=old(C)).
samer@0 166
samer@0 167 */
samer@0 168
samer@0 169 foreign_t crp_sample_obs( term_t Alpha, term_t Classes, term_t X, term_t Probx, term_t Act, term_t Rnd1, term_t Rnd2)
samer@0 170 {
samer@0 171 term_t Counts=PL_new_term_ref();
samer@0 172 term_t Vals=PL_new_term_ref();
samer@0 173 PL_blob_t *blob;
samer@0 174 double probx=0;
samer@0 175 double *dist=NULL;
samer@0 176 long len=0;
samer@0 177 RndState rs;
samer@0 178
samer@0 179 int rc = get_double(Probx,&probx)
samer@0 180 && get_classes(Classes, Counts, Vals, &len)
samer@0 181 && alloc_array(len+1, sizeof(double), (void **)&dist)
samer@0 182 && counts_dist(Alpha, Counts, len, dist)
samer@0 183 && get_state(Rnd1,&rs,&blob);
samer@0 184
samer@0 185 if (rc) {
samer@0 186 term_t Val=PL_new_term_ref();
samer@0 187 int i, z;
samer@0 188
samer@0 189 dist[0] *= probx;
samer@0 190 for (i=1; i<=len && PL_get_list(Vals,Val,Vals); i++) {
samer@0 191 if (!PL_unify(Val,X)) dist[i]=0;
samer@0 192 }
samer@0 193
samer@0 194 z=Discrete( &rs, len+1, dist, sum_array(dist,len+1));
samer@0 195 if (z==0) { rc = PL_unify_atom(Act,atom_new); }
samer@0 196 else {
samer@0 197 rc = PL_unify_term(Act, PL_FUNCTOR, functor_old1, PL_INT, z);
samer@0 198 }
samer@0 199 }
samer@0 200 if (dist) free(dist);
samer@0 201 return rc && unify_state(Rnd2,&rs,blob);
samer@0 202 }
samer@0 203
samer@0 204 /*
samer@0 205 %% crp_sample_rm( +Classes:classes(A), +X:A, -C:natural)// is det.
samer@0 206 %
samer@0 207 % Sample appropriate class from which to remove value X.
samer@0 208 % Operates in random state DCG.
samer@0 209 crp_sample_rm( classes(Counts,Vals), X, Class, RS1, RS2) :-
samer@0 210 maplist(post_count(X),Vals,Counts,Counts1),
samer@0 211 discrete( Counts1, Class, RS1, RS2).
samer@0 212
samer@0 213 */
samer@0 214
samer@0 215 foreign_t crp_sample_rm( term_t Classes, term_t X, term_t Class, term_t Rnd1, term_t Rnd2)
samer@0 216 {
samer@0 217 term_t Counts=PL_new_term_ref();
samer@0 218 term_t Vals=PL_new_term_ref();
samer@0 219 PL_blob_t *blob;
samer@0 220 double *dist=NULL;
samer@0 221 long len=0;
samer@0 222 RndState rs;
samer@0 223
samer@0 224 int rc = get_classes(Classes, Counts, Vals, &len)
samer@0 225 && alloc_array(len, sizeof(double), (void **)&dist)
samer@0 226 && get_list_doubles(Counts, dist, len)
samer@0 227 && get_state(Rnd1,&rs,&blob);
samer@0 228
samer@0 229 if (rc) {
samer@0 230 term_t Val=PL_new_term_ref();
samer@0 231 int i, z;
samer@0 232
samer@0 233 for (i=0; i<len && PL_get_list(Vals,Val,Vals); i++) {
samer@0 234 if (!PL_unify(Val,X)) dist[i]=0;
samer@0 235 }
samer@0 236
samer@0 237 z = Discrete( &rs, len, dist, sum_array(dist,len));
samer@0 238 rc = (z<len) && PL_unify_integer(Class, z+1);
samer@0 239 }
samer@0 240 if (dist) free(dist);
samer@0 241 return rc && unify_state(Rnd2,&rs,blob);
samer@0 242 }
samer@0 243
samer@0 244 /*
samer@0 245 post_count(X,Val,Count,PC) :- X=Val -> PC=Count; PC=0.
samer@0 246
samer@0 247 % -----------------------------------------------------------
samer@0 248 % Dirichlet process and Pitman-Yor process
samer@0 249 % pseudo-counts models.
samer@0 250
samer@0 251 counts_dist(_,[],0,[1]) :- !.
samer@0 252 counts_dist(dp(Alpha),Counts,_,[Alpha|Counts]) :- !.
samer@0 253 counts_dist(py(Alpha,Discount),Counts,K,[CNew|Counts1]) :- !,
samer@0 254 CNew is Alpha+Discount*K,
samer@0 255 maplist(sub(Discount),Counts,Counts1).
samer@0 256
samer@0 257 */
samer@0 258
samer@0 259 int get_float_arg(int n,term_t Term, double *px)
samer@0 260 {
samer@0 261 term_t X=PL_new_term_ref();
samer@0 262 return PL_get_arg(n,Term,X) && PL_get_float(X,px);
samer@0 263 }
samer@0 264
samer@0 265 int counts_dist( term_t gem, term_t counts, size_t len, double *dist)
samer@0 266 {
samer@0 267 if (len==0) { dist[0]=1; return TRUE; }
samer@0 268 else {
samer@0 269 if (PL_is_functor(gem, functor_dp1)) {
samer@0 270 double alpha;
samer@0 271 term_t head=PL_new_term_ref();
samer@0 272 int i, rc = get_float_arg(1,gem,&alpha);
samer@0 273
samer@0 274 dist[0] = alpha;
samer@0 275 for(i=1; rc && i<=len && PL_get_list(counts,head,counts); i++) {
samer@0 276 rc = rc && PL_get_float(head,&dist[i]);
samer@0 277 }
samer@0 278 return rc;
samer@0 279 } else if (PL_is_functor(gem, functor_py2)) {
samer@0 280 double theta, disc, c;
samer@0 281 term_t head=PL_new_term_ref();
samer@0 282
samer@0 283 int i, rc = get_float_arg(1,gem,&theta)
samer@0 284 && get_float_arg(2,gem,&disc);
samer@0 285
samer@0 286 dist[0] = theta + disc*len;
samer@0 287 for(i=1; rc && i<=len && PL_get_list(counts,head,counts); i++) {
samer@0 288 rc = rc && PL_get_float(head,&c);
samer@0 289 dist[i] = c-disc;
samer@0 290 }
samer@0 291 return rc;
samer@0 292 } else return FALSE;
samer@0 293 }
samer@0 294 }
samer@0 295
samer@0 296 int get_classes(term_t Classes, term_t Counts, term_t Vals, long *len)
samer@0 297 {
samer@0 298 term_t K=PL_new_term_ref();
samer@0 299
samer@0 300 return PL_get_arg(1,Classes,K)
samer@0 301 && PL_get_arg(2,Classes,Counts)
samer@0 302 && PL_get_arg(3,Classes,Vals)
samer@0 303 && PL_get_long(K,len);
samer@0 304 }
samer@0 305
samer@0 306
// Normalise x[0..len-1] in place so that its elements sum to 1.
// Does nothing when len is 0. If the elements sum to zero, the divisions
// produce non-finite values, so callers presumably pass at least one
// positive weight.
void stoch(double *x, size_t len)
{
  double total = 0;
  size_t i;   // size_t to match len (an int index overflows for len > INT_MAX)

  for (i = 0; i < len; i++) total += x[i];
  for (i = 0; i < len; i++) x[i] /= total;
}
samer@0 314
samer@0 315 /*
samer@0 316 sample_dp_teh( ApSumKX, B, NX, dp(Alpha1), dp(Alpha2)) -->
samer@0 317 { Alpha1_1 is Alpha1+1 },
samer@0 318 seqmap(beta(Alpha1_1),NX,WX),
samer@0 319 seqmap(bernoulli(Alpha1),NX,SX),
samer@0 320 { maplist(log,WX,LogWX),
samer@0 321 sumlist(SX,SumSX),
samer@0 322 sumlist(LogWX,SumLogWX),
samer@0 323 A1 is ApSumKX-SumSX, B1 is B-SumLogWX
samer@0 324 },
samer@0 325 gamma(A1,B1,Alpha2).
samer@0 326
samer@0 327 % run_left( seqmap(accum_log_beta(Alpha1_1),NX), 0, SumLogWX),
samer@0 328 % run_left( seqmap(accum_bernoulli(Alpha1),NX), 0, SumSX),
samer@0 329 %accum_log_beta(A,B) --> \> beta(A,B,X), { LogX is log(X) }, \< add(LogX).
samer@0 330 %accum_bernoulli(A,B) --> \> bernoulli(A,B,X), \< add(X).
samer@0 331
samer@0 332 */
samer@0 333
samer@0 334 int Bernoulli(RndState *rs,double a,double b) {
samer@0 335 if ((a+b)*Uniform(rs)<b) return 1; else return 0;
samer@0 336 }
samer@0 337
samer@0 338
samer@0 339 foreign_t sample_dp_teh( term_t ApSumKX, term_t B, term_t NX, term_t p1, term_t p2, term_t rnd1, term_t rnd2)
samer@0 340 {
samer@0 341 term_t N=PL_new_term_ref();
samer@0 342 PL_blob_t *pblob;
samer@0 343 double apsumkx, b, alphap1;
samer@0 344 double alpha1=0, alpha2;
samer@0 345 double sum_log_wx, sum_sx;
samer@0 346 long n=0;
samer@0 347 RndState rs;
samer@0 348
samer@0 349 int rc = get_double(ApSumKX, &apsumkx)
samer@0 350 && get_double(B, &b)
samer@0 351 && get_float_arg(1,p1,&alpha1)
samer@0 352 && get_state(rnd1,&rs,&pblob);
samer@0 353
samer@0 354 alphap1 = alpha1+1;
samer@0 355 sum_log_wx = sum_sx = 0;
samer@0 356 while (rc && PL_get_list(NX,N,NX)) {
samer@0 357 rc = get_long(N,&n);
samer@0 358 sum_log_wx += log(Beta(&rs,alphap1,n));
samer@0 359 sum_sx += Bernoulli(&rs,alpha1,n);
samer@0 360 }
samer@0 361 alpha2 = Gamma(&rs, apsumkx-sum_sx)/(b-sum_log_wx);
samer@0 362
samer@0 363 return rc && PL_unify_term(p2, PL_FUNCTOR, functor_dp1, PL_FLOAT, alpha2)
samer@0 364 && unify_state(rnd2,&rs,pblob);
samer@0 365 }
samer@0 366
samer@0 367 foreign_t sample_py_teh( term_t ThPrior, term_t DPrior, term_t CountsX, term_t p1, term_t p2, term_t rnd1, term_t rnd2)
samer@0 368 {
samer@0 369 PL_blob_t *pblob;
samer@0 370 term_t Counts = PL_new_term_ref();
samer@0 371 term_t Count = PL_new_term_ref();
samer@0 372 double theta_a, disc_a;
samer@0 373 double theta_b, disc_b;
samer@0 374 double theta1, disc1;
samer@0 375 double theta2, disc2;
samer@0 376 double theta1_1, disc1_1;
samer@0 377 double sum_log_wx, sum_sx, sum_nsx, sum_zx;
samer@0 378 RndState rs;
samer@0 379
samer@0 380 int rc = get_float_arg(1,ThPrior,&theta_a)
samer@0 381 && get_float_arg(2,ThPrior,&theta_b)
samer@0 382 && get_float_arg(1,DPrior,&disc_a)
samer@0 383 && get_float_arg(2,DPrior,&disc_b)
samer@0 384 && get_float_arg(1,p1,&theta1)
samer@0 385 && get_float_arg(2,p1,&disc1)
samer@0 386 && get_state(rnd1,&rs,&pblob);
samer@0 387
samer@0 388 theta1_1 = theta1+1;
samer@0 389 disc1_1 = 1-disc1;
samer@0 390 sum_log_wx = sum_sx = sum_nsx = sum_zx = 0;
samer@0 391 while (rc && PL_get_list(CountsX,Counts,CountsX)) {
samer@0 392 int n, k, i;
samer@0 393 long c=0;
samer@0 394
samer@0 395 for(k=0, n=0; rc && PL_get_list(Counts,Count,Counts); k++, n+=c) {
samer@0 396 rc = get_long(Count,&c);
samer@0 397 if (k>0) { if (Bernoulli(&rs, disc1*k, theta1)) sum_sx++; else sum_nsx++; }
samer@0 398 for (i=0; i<c-1; i++) sum_zx += Bernoulli(&rs, i, disc1_1);
samer@0 399 }
samer@0 400 if (n>1) sum_log_wx += log(Beta(&rs, theta1_1, n-1));
samer@0 401 }
samer@0 402
samer@0 403 theta2 = Gamma(&rs, theta_a + sum_sx)/(theta_b-sum_log_wx);
samer@0 404 disc2 = Beta(&rs, disc_a + sum_nsx, disc_b + sum_zx);
samer@0 405 return rc && unify_state(rnd2,&rs,pblob)
samer@0 406 && PL_unify_term(p2, PL_FUNCTOR, functor_py2, PL_FLOAT, theta2, PL_FLOAT, disc2);
samer@0 407 }
samer@0 408
samer@0 409 /*
samer@0 410 foreign_t sum_lengths( term_t Lists, term_t Total)
samer@0 411 {
samer@0 412 double total=0;
samer@0 413 size_t len=0;
samer@0 414 term_t List=PL_new_term_ref();
samer@0 415 term_t Tail=PL_new_term_ref();
samer@0 416 int rc=1;
samer@0 417
samer@0 418 while (rc && PL_get_list(Lists,List,Lists)) {
samer@0 419 rc = PL_skip_list(List,Tail,&len);
samer@0 420 total += len;
samer@0 421 }
samer@0 422 return rc && PL_unify_integer(Total, total);
samer@0 423 }
samer@0 424 */