samer@0
|
1 /*
|
samer@0
|
2 */
|
samer@0
|
3
|
samer@0
|
4 #include <SWI-Prolog.h>
|
samer@0
|
5 #include <math.h>
|
samer@0
|
6 #include <float.h>
|
samer@0
|
7 #include <stdio.h>
|
samer@0
|
8
|
samer@0
|
9 #include "rndutils.h"
|
samer@0
|
10 #include "plutils.c"
|
samer@0
|
11
|
samer@0
|
12 install_t install();
|
samer@0
|
13
|
samer@0
|
14 foreign_t crp_prob( term_t alpha, term_t classes, term_t x, term_t pprob, term_t p);
|
samer@0
|
15 foreign_t crp_sample( term_t alpha, term_t classes, term_t action, term_t rnd1, term_t rnd2);
|
samer@0
|
16 foreign_t crp_sample_obs( term_t alpha, term_t classes, term_t x, term_t probx, term_t act, term_t rnd1, term_t rnd2);
|
samer@0
|
17 foreign_t crp_sample_rm( term_t classes, term_t x, term_t class, term_t rnd1, term_t rnd2);
|
samer@0
|
18 foreign_t sample_dp_teh( term_t ApSumKX, term_t B, term_t NX, term_t p1, term_t p2, term_t rnd1, term_t rnd2);
|
samer@0
|
19 foreign_t sample_py_teh( term_t ThPrior, term_t DPrior, term_t CountsX, term_t p1, term_t p2, term_t rnd1, term_t rnd2);
|
samer@0
|
20
|
samer@0
|
21 static atom_t atom_new;
|
samer@0
|
22 static functor_t functor_old1, functor_old2;
|
samer@0
|
23 static functor_t functor_dp1, functor_py2;
|
samer@0
|
24
|
samer@0
|
25 install_t install() {
|
samer@0
|
26 PL_register_foreign("crp_prob", 5, (void *)crp_prob, 0);
|
samer@0
|
27 PL_register_foreign("crp_sample", 5, (void *)crp_sample, 0);
|
samer@0
|
28 PL_register_foreign("crp_sample_obs", 7, (void *)crp_sample_obs, 0);
|
samer@0
|
29 PL_register_foreign("crp_sample_rm", 5, (void *)crp_sample_rm, 0);
|
samer@0
|
30 PL_register_foreign("sample_dp_teh", 7, (void *)sample_dp_teh, 0);
|
samer@0
|
31 PL_register_foreign("sample_py_teh", 7, (void *)sample_py_teh, 0);
|
samer@0
|
32
|
samer@0
|
33 functor_dp1 = PL_new_functor(PL_new_atom("dp"),1);
|
samer@0
|
34 functor_py2 = PL_new_functor(PL_new_atom("py"),2);
|
samer@0
|
35 functor_old1 = PL_new_functor(PL_new_atom("old"),1);
|
samer@0
|
36 functor_old2 = PL_new_functor(PL_new_atom("old"),2);
|
samer@0
|
37 atom_new = PL_new_atom("new");
|
samer@0
|
38 }
|
samer@0
|
39
|
samer@0
|
40
|
samer@0
|
41 // unify Prolog BLOB with RndState structure
|
samer@0
|
42 static int unify_state(term_t state, RndState *S, PL_blob_t *pblob) {
|
samer@0
|
43 return PL_unify_blob(state, S, sizeof(RndState), pblob);
|
samer@0
|
44 }
|
samer@0
|
45
|
samer@0
|
46 // extract RndState structure from Prolog BLOB
|
samer@0
|
47 static int get_state(term_t state, RndState *S0, PL_blob_t **ppblob)
|
samer@0
|
48 {
|
samer@0
|
49 size_t len;
|
samer@0
|
50 RndState *S;
|
samer@0
|
51
|
samer@0
|
52 PL_get_blob(state, (void **)&S, &len, ppblob);
|
samer@0
|
53 *S0=*S;
|
samer@0
|
54 return TRUE;
|
samer@0
|
55 }
|
samer@0
|
56
|
samer@0
|
57
|
samer@0
|
58 // -----------------------------------------------------
|
samer@0
|
59 // Prolog versions of functions to implement
|
samer@0
|
60
|
samer@0
|
61 int counts_dist( term_t gem, term_t counts, size_t len, double *dist);
|
samer@0
|
62 int get_classes(term_t Classes, term_t Counts, term_t Vals, long *len);
|
samer@0
|
63 void stoch(double *x, size_t len);
|
samer@0
|
64
|
samer@0
|
65 /*
|
samer@0
|
66 %% crp_prob( +GEM:gem_model, +Classes:classes(A), +X:A, +PProb:float, -Prob:float) is det.
|
samer@0
|
67 %
|
samer@0
|
68 % Compute the probability Prob of observing X given a CRP
|
samer@0
|
69 % and a base probability of PProb.
|
samer@0
|
70 crp_prob( Alpha, classes(Counts,Vals), X, PProb, P) :-
|
samer@0
|
71 counts_dist( Alpha, Counts, Counts1),
|
samer@0
|
72 stoch( Counts1, Probs, _),
|
samer@0
|
73 maplist( equal(X), Vals, Mask),
|
samer@0
|
74 maplist( mul, [PProb | Mask], Probs, PostProbs),
|
samer@0
|
75 sumlist( PostProbs, P).
|
samer@0
|
76
|
samer@0
|
77 */
|
samer@0
|
78
|
samer@0
|
79 foreign_t crp_prob( term_t Alpha, term_t Classes, term_t X, term_t PProb, term_t Prob)
|
samer@0
|
80 {
|
samer@0
|
81 term_t Counts=PL_new_term_ref();
|
samer@0
|
82 term_t Vals=PL_new_term_ref();
|
samer@0
|
83 double prob=0, pprob;
|
samer@0
|
84 double *dist=NULL;
|
samer@0
|
85 long len=0;
|
samer@0
|
86
|
samer@0
|
87 int rc = get_double(PProb, &pprob)
|
samer@0
|
88 && get_classes(Classes, Counts, Vals, &len)
|
samer@0
|
89 && alloc_array(len+1, sizeof(double), (void **)&dist)
|
samer@0
|
90 && counts_dist(Alpha, Counts, len, dist);
|
samer@0
|
91
|
samer@0
|
92 if (rc) {
|
samer@0
|
93 term_t Val = PL_new_term_ref();
|
samer@0
|
94 int i;
|
samer@0
|
95
|
samer@0
|
96 stoch(dist,len+1);
|
samer@0
|
97 prob = pprob*dist[0];
|
samer@0
|
98 for (i=1; i<=len && PL_get_list(Vals,Val,Vals); i++) {
|
samer@0
|
99 if (PL_unify(Val,X)) prob += dist[i];
|
samer@0
|
100 }
|
samer@0
|
101 } else rc=0;
|
samer@0
|
102 if (dist) free(dist);
|
samer@0
|
103 return rc && PL_unify_float(Prob,prob);
|
samer@0
|
104 }
|
samer@0
|
105
|
samer@0
|
106 /*
|
samer@0
|
107
|
samer@0
|
108
|
samer@0
|
109 %% crp_sample( +GEM:gem_model, +Classes:classes(A), -A:action(A))// is det.
|
samer@0
|
110 %
|
samer@0
|
111 % Sample a new value from CRP, Action A is either new, which means
|
samer@0
|
112 % that the user should sample a new value from the base distribution,
|
samer@0
|
113 % or old(X,C), where X is an old value and C is the index of its class.
|
samer@0
|
114 % Operates in random state DCG.
|
samer@0
|
115 crp_sample( Alpha, classes(Counts,Vals), Action, RS1, RS2) :-
|
samer@0
|
116 counts_dist(Alpha, Counts, Counts1),
|
samer@0
|
117 discrete(Counts1,Z,RS1,RS2),
|
samer@0
|
118 ( Z>1 -> succ(C,Z), nth1(C,Vals,X), Action=old(X,C)
|
samer@0
|
119 ; Action=new).
|
samer@0
|
120
|
samer@0
|
121 */
|
samer@0
|
122
|
samer@0
|
123 foreign_t crp_sample( term_t Alpha, term_t Classes, term_t Action, term_t Rnd1, term_t Rnd2)
|
samer@0
|
124 {
|
samer@0
|
125 term_t Counts=PL_new_term_ref();
|
samer@0
|
126 term_t Vals=PL_new_term_ref();
|
samer@0
|
127 PL_blob_t *blob;
|
samer@0
|
128 double *dist=NULL;
|
samer@0
|
129 RndState rs;
|
samer@0
|
130 long len=0;
|
samer@0
|
131
|
samer@0
|
132 int rc = get_classes(Classes, Counts, Vals, &len)
|
samer@0
|
133 && alloc_array(len+1, sizeof(double), (void **)&dist)
|
samer@0
|
134 && counts_dist(Alpha, Counts, len, dist)
|
samer@0
|
135 && get_state(Rnd1,&rs,&blob);
|
samer@0
|
136
|
samer@0
|
137 if (rc) {
|
samer@0
|
138 int z=Discrete( &rs, len+1, dist, sum_array(dist,len+1));
|
samer@0
|
139
|
samer@0
|
140 if (z==0) { rc = PL_unify_atom(Action,atom_new); }
|
samer@0
|
141 else {
|
samer@0
|
142 term_t X=PL_new_term_ref();
|
samer@0
|
143 int i=0;
|
samer@0
|
144 while (i<z && PL_get_list(Vals,X,Vals)) i++;
|
samer@0
|
145 rc = (i==z) && PL_unify_term(Action, PL_FUNCTOR, functor_old2, PL_TERM, X, PL_INT, z);
|
samer@0
|
146 }
|
samer@0
|
147 }
|
samer@0
|
148 if (dist) free(dist);
|
samer@0
|
149 return rc && unify_state(Rnd2,&rs,blob);
|
samer@0
|
150 }
|
samer@0
|
151
|
samer@0
|
152 /*
|
samer@0
|
153
|
samer@0
|
154 %% crp_sample_obs( +GEM:gem_model, +Classes:classes(A), +X:A, +PProb:float, -A:action)// is det.
|
samer@0
|
155 %
|
samer@0
|
156 % Sample class appropriate for observation of value X. PProb is the
|
samer@0
|
157 % base probability of X from the base distribution. Action A is new
|
samer@0
|
158 % or old(Class).
|
samer@0
|
159 % Operates in random state DCG.
|
samer@0
|
160 crp_sample_obs( Alpha, classes(Counts,Vals), X, ProbX, A, RS1, RS2) :-
|
samer@0
|
161 counts_dist( Alpha, Counts, [CNew|Counts1]),
|
samer@0
|
162 PNew is CNew*ProbX,
|
samer@0
|
163 maplist( post_count(X),Vals,Counts1,Counts2),
|
samer@0
|
164 discrete( [PNew|Counts2], Z, RS1, RS2),
|
samer@0
|
165 ( Z=1 -> A=new; succ(C,Z), A=old(C)).
|
samer@0
|
166
|
samer@0
|
167 */
|
samer@0
|
168
|
samer@0
|
169 foreign_t crp_sample_obs( term_t Alpha, term_t Classes, term_t X, term_t Probx, term_t Act, term_t Rnd1, term_t Rnd2)
|
samer@0
|
170 {
|
samer@0
|
171 term_t Counts=PL_new_term_ref();
|
samer@0
|
172 term_t Vals=PL_new_term_ref();
|
samer@0
|
173 PL_blob_t *blob;
|
samer@0
|
174 double probx=0;
|
samer@0
|
175 double *dist=NULL;
|
samer@0
|
176 long len=0;
|
samer@0
|
177 RndState rs;
|
samer@0
|
178
|
samer@0
|
179 int rc = get_double(Probx,&probx)
|
samer@0
|
180 && get_classes(Classes, Counts, Vals, &len)
|
samer@0
|
181 && alloc_array(len+1, sizeof(double), (void **)&dist)
|
samer@0
|
182 && counts_dist(Alpha, Counts, len, dist)
|
samer@0
|
183 && get_state(Rnd1,&rs,&blob);
|
samer@0
|
184
|
samer@0
|
185 if (rc) {
|
samer@0
|
186 term_t Val=PL_new_term_ref();
|
samer@0
|
187 int i, z;
|
samer@0
|
188
|
samer@0
|
189 dist[0] *= probx;
|
samer@0
|
190 for (i=1; i<=len && PL_get_list(Vals,Val,Vals); i++) {
|
samer@0
|
191 if (!PL_unify(Val,X)) dist[i]=0;
|
samer@0
|
192 }
|
samer@0
|
193
|
samer@0
|
194 z=Discrete( &rs, len+1, dist, sum_array(dist,len+1));
|
samer@0
|
195 if (z==0) { rc = PL_unify_atom(Act,atom_new); }
|
samer@0
|
196 else {
|
samer@0
|
197 rc = PL_unify_term(Act, PL_FUNCTOR, functor_old1, PL_INT, z);
|
samer@0
|
198 }
|
samer@0
|
199 }
|
samer@0
|
200 if (dist) free(dist);
|
samer@0
|
201 return rc && unify_state(Rnd2,&rs,blob);
|
samer@0
|
202 }
|
samer@0
|
203
|
samer@0
|
204 /*
|
samer@0
|
205 %% crp_sample_rm( +Classes:classes(A), +X:A, -C:natural)// is det.
|
samer@0
|
206 %
|
samer@0
|
207 % Sample appropriate class from which to remove value X.
|
samer@0
|
208 % Operates in random state DCG.
|
samer@0
|
209 crp_sample_rm( classes(Counts,Vals), X, Class, RS1, RS2) :-
|
samer@0
|
210 maplist(post_count(X),Vals,Counts,Counts1),
|
samer@0
|
211 discrete( Counts1, Class, RS1, RS2).
|
samer@0
|
212
|
samer@0
|
213 */
|
samer@0
|
214
|
samer@0
|
215 foreign_t crp_sample_rm( term_t Classes, term_t X, term_t Class, term_t Rnd1, term_t Rnd2)
|
samer@0
|
216 {
|
samer@0
|
217 term_t Counts=PL_new_term_ref();
|
samer@0
|
218 term_t Vals=PL_new_term_ref();
|
samer@0
|
219 PL_blob_t *blob;
|
samer@0
|
220 double *dist=NULL;
|
samer@0
|
221 long len=0;
|
samer@0
|
222 RndState rs;
|
samer@0
|
223
|
samer@0
|
224 int rc = get_classes(Classes, Counts, Vals, &len)
|
samer@0
|
225 && alloc_array(len, sizeof(double), (void **)&dist)
|
samer@0
|
226 && get_list_doubles(Counts, dist, len)
|
samer@0
|
227 && get_state(Rnd1,&rs,&blob);
|
samer@0
|
228
|
samer@0
|
229 if (rc) {
|
samer@0
|
230 term_t Val=PL_new_term_ref();
|
samer@0
|
231 int i, z;
|
samer@0
|
232
|
samer@0
|
233 for (i=0; i<len && PL_get_list(Vals,Val,Vals); i++) {
|
samer@0
|
234 if (!PL_unify(Val,X)) dist[i]=0;
|
samer@0
|
235 }
|
samer@0
|
236
|
samer@0
|
237 z = Discrete( &rs, len, dist, sum_array(dist,len));
|
samer@0
|
238 rc = (z<len) && PL_unify_integer(Class, z+1);
|
samer@0
|
239 }
|
samer@0
|
240 if (dist) free(dist);
|
samer@0
|
241 return rc && unify_state(Rnd2,&rs,blob);
|
samer@0
|
242 }
|
samer@0
|
243
|
samer@0
|
244 /*
|
samer@0
|
245 post_count(X,Val,Count,PC) :- X=Val -> PC=Count; PC=0.
|
samer@0
|
246
|
samer@0
|
247 % -----------------------------------------------------------
|
samer@0
|
248 % Dirichlet process and Pitman-Yor process
|
samer@0
|
249 % pseudo-counts models.
|
samer@0
|
250
|
samer@0
|
251 counts_dist(_,[],0,[1]) :- !.
|
samer@0
|
252 counts_dist(dp(Alpha),Counts,_,[Alpha|Counts]) :- !.
|
samer@0
|
253 counts_dist(py(Alpha,Discount),Counts,K,[CNew|Counts1]) :- !,
|
samer@0
|
254 CNew is Alpha+Discount*K,
|
samer@0
|
255 maplist(sub(Discount),Counts,Counts1).
|
samer@0
|
256
|
samer@0
|
257 */
|
samer@0
|
258
|
samer@0
|
259 int get_float_arg(int n,term_t Term, double *px)
|
samer@0
|
260 {
|
samer@0
|
261 term_t X=PL_new_term_ref();
|
samer@0
|
262 return PL_get_arg(n,Term,X) && PL_get_float(X,px);
|
samer@0
|
263 }
|
samer@0
|
264
|
samer@0
|
265 int counts_dist( term_t gem, term_t counts, size_t len, double *dist)
|
samer@0
|
266 {
|
samer@0
|
267 if (len==0) { dist[0]=1; return TRUE; }
|
samer@0
|
268 else {
|
samer@0
|
269 if (PL_is_functor(gem, functor_dp1)) {
|
samer@0
|
270 double alpha;
|
samer@0
|
271 term_t head=PL_new_term_ref();
|
samer@0
|
272 int i, rc = get_float_arg(1,gem,&alpha);
|
samer@0
|
273
|
samer@0
|
274 dist[0] = alpha;
|
samer@0
|
275 for(i=1; rc && i<=len && PL_get_list(counts,head,counts); i++) {
|
samer@0
|
276 rc = rc && PL_get_float(head,&dist[i]);
|
samer@0
|
277 }
|
samer@0
|
278 return rc;
|
samer@0
|
279 } else if (PL_is_functor(gem, functor_py2)) {
|
samer@0
|
280 double theta, disc, c;
|
samer@0
|
281 term_t head=PL_new_term_ref();
|
samer@0
|
282
|
samer@0
|
283 int i, rc = get_float_arg(1,gem,&theta)
|
samer@0
|
284 && get_float_arg(2,gem,&disc);
|
samer@0
|
285
|
samer@0
|
286 dist[0] = theta + disc*len;
|
samer@0
|
287 for(i=1; rc && i<=len && PL_get_list(counts,head,counts); i++) {
|
samer@0
|
288 rc = rc && PL_get_float(head,&c);
|
samer@0
|
289 dist[i] = c-disc;
|
samer@0
|
290 }
|
samer@0
|
291 return rc;
|
samer@0
|
292 } else return FALSE;
|
samer@0
|
293 }
|
samer@0
|
294 }
|
samer@0
|
295
|
samer@0
|
296 int get_classes(term_t Classes, term_t Counts, term_t Vals, long *len)
|
samer@0
|
297 {
|
samer@0
|
298 term_t K=PL_new_term_ref();
|
samer@0
|
299
|
samer@0
|
300 return PL_get_arg(1,Classes,K)
|
samer@0
|
301 && PL_get_arg(2,Classes,Counts)
|
samer@0
|
302 && PL_get_arg(3,Classes,Vals)
|
samer@0
|
303 && PL_get_long(K,len);
|
samer@0
|
304 }
|
samer@0
|
305
|
samer@0
|
306
|
samer@0
|
/* Normalise x[0..len-1] in place so that the entries sum to 1.
 * No guard for a zero total: under IEEE 754 the result is then NaN/Inf.
 * The index is size_t, matching len (the original compared an int index
 * with a size_t bound). */
void stoch(double *x, size_t len)
{
  double total = 0;
  size_t i;

  for (i = 0; i < len; i++) total += x[i];
  for (i = 0; i < len; i++) x[i] /= total;
}
|
samer@0
|
314
|
samer@0
|
315 /*
|
samer@0
|
316 sample_dp_teh( ApSumKX, B, NX, dp(Alpha1), dp(Alpha2)) -->
|
samer@0
|
317 { Alpha1_1 is Alpha1+1 },
|
samer@0
|
318 seqmap(beta(Alpha1_1),NX,WX),
|
samer@0
|
319 seqmap(bernoulli(Alpha1),NX,SX),
|
samer@0
|
320 { maplist(log,WX,LogWX),
|
samer@0
|
321 sumlist(SX,SumSX),
|
samer@0
|
322 sumlist(LogWX,SumLogWX),
|
samer@0
|
323 A1 is ApSumKX-SumSX, B1 is B-SumLogWX
|
samer@0
|
324 },
|
samer@0
|
325 gamma(A1,B1,Alpha2).
|
samer@0
|
326
|
samer@0
|
327 % run_left( seqmap(accum_log_beta(Alpha1_1),NX), 0, SumLogWX),
|
samer@0
|
328 % run_left( seqmap(accum_bernoulli(Alpha1),NX), 0, SumSX),
|
samer@0
|
329 %accum_log_beta(A,B) --> \> beta(A,B,X), { LogX is log(X) }, \< add(LogX).
|
samer@0
|
330 %accum_bernoulli(A,B) --> \> bernoulli(A,B,X), \< add(X).
|
samer@0
|
331
|
samer@0
|
332 */
|
samer@0
|
333
|
samer@0
|
334 int Bernoulli(RndState *rs,double a,double b) {
|
samer@0
|
335 if ((a+b)*Uniform(rs)<b) return 1; else return 0;
|
samer@0
|
336 }
|
samer@0
|
337
|
samer@0
|
338
|
samer@0
|
339 foreign_t sample_dp_teh( term_t ApSumKX, term_t B, term_t NX, term_t p1, term_t p2, term_t rnd1, term_t rnd2)
|
samer@0
|
340 {
|
samer@0
|
341 term_t N=PL_new_term_ref();
|
samer@0
|
342 PL_blob_t *pblob;
|
samer@0
|
343 double apsumkx, b, alphap1;
|
samer@0
|
344 double alpha1=0, alpha2;
|
samer@0
|
345 double sum_log_wx, sum_sx;
|
samer@0
|
346 long n=0;
|
samer@0
|
347 RndState rs;
|
samer@0
|
348
|
samer@0
|
349 int rc = get_double(ApSumKX, &apsumkx)
|
samer@0
|
350 && get_double(B, &b)
|
samer@0
|
351 && get_float_arg(1,p1,&alpha1)
|
samer@0
|
352 && get_state(rnd1,&rs,&pblob);
|
samer@0
|
353
|
samer@0
|
354 alphap1 = alpha1+1;
|
samer@0
|
355 sum_log_wx = sum_sx = 0;
|
samer@0
|
356 while (rc && PL_get_list(NX,N,NX)) {
|
samer@0
|
357 rc = get_long(N,&n);
|
samer@0
|
358 sum_log_wx += log(Beta(&rs,alphap1,n));
|
samer@0
|
359 sum_sx += Bernoulli(&rs,alpha1,n);
|
samer@0
|
360 }
|
samer@0
|
361 alpha2 = Gamma(&rs, apsumkx-sum_sx)/(b-sum_log_wx);
|
samer@0
|
362
|
samer@0
|
363 return rc && PL_unify_term(p2, PL_FUNCTOR, functor_dp1, PL_FLOAT, alpha2)
|
samer@0
|
364 && unify_state(rnd2,&rs,pblob);
|
samer@0
|
365 }
|
samer@0
|
366
|
samer@0
|
367 foreign_t sample_py_teh( term_t ThPrior, term_t DPrior, term_t CountsX, term_t p1, term_t p2, term_t rnd1, term_t rnd2)
|
samer@0
|
368 {
|
samer@0
|
369 PL_blob_t *pblob;
|
samer@0
|
370 term_t Counts = PL_new_term_ref();
|
samer@0
|
371 term_t Count = PL_new_term_ref();
|
samer@0
|
372 double theta_a, disc_a;
|
samer@0
|
373 double theta_b, disc_b;
|
samer@0
|
374 double theta1, disc1;
|
samer@0
|
375 double theta2, disc2;
|
samer@0
|
376 double theta1_1, disc1_1;
|
samer@0
|
377 double sum_log_wx, sum_sx, sum_nsx, sum_zx;
|
samer@0
|
378 RndState rs;
|
samer@0
|
379
|
samer@0
|
380 int rc = get_float_arg(1,ThPrior,&theta_a)
|
samer@0
|
381 && get_float_arg(2,ThPrior,&theta_b)
|
samer@0
|
382 && get_float_arg(1,DPrior,&disc_a)
|
samer@0
|
383 && get_float_arg(2,DPrior,&disc_b)
|
samer@0
|
384 && get_float_arg(1,p1,&theta1)
|
samer@0
|
385 && get_float_arg(2,p1,&disc1)
|
samer@0
|
386 && get_state(rnd1,&rs,&pblob);
|
samer@0
|
387
|
samer@0
|
388 theta1_1 = theta1+1;
|
samer@0
|
389 disc1_1 = 1-disc1;
|
samer@0
|
390 sum_log_wx = sum_sx = sum_nsx = sum_zx = 0;
|
samer@0
|
391 while (rc && PL_get_list(CountsX,Counts,CountsX)) {
|
samer@0
|
392 int n, k, i;
|
samer@0
|
393 long c=0;
|
samer@0
|
394
|
samer@0
|
395 for(k=0, n=0; rc && PL_get_list(Counts,Count,Counts); k++, n+=c) {
|
samer@0
|
396 rc = get_long(Count,&c);
|
samer@0
|
397 if (k>0) { if (Bernoulli(&rs, disc1*k, theta1)) sum_sx++; else sum_nsx++; }
|
samer@0
|
398 for (i=0; i<c-1; i++) sum_zx += Bernoulli(&rs, i, disc1_1);
|
samer@0
|
399 }
|
samer@0
|
400 if (n>1) sum_log_wx += log(Beta(&rs, theta1_1, n-1));
|
samer@0
|
401 }
|
samer@0
|
402
|
samer@0
|
403 theta2 = Gamma(&rs, theta_a + sum_sx)/(theta_b-sum_log_wx);
|
samer@0
|
404 disc2 = Beta(&rs, disc_a + sum_nsx, disc_b + sum_zx);
|
samer@0
|
405 return rc && unify_state(rnd2,&rs,pblob)
|
samer@0
|
406 && PL_unify_term(p2, PL_FUNCTOR, functor_py2, PL_FLOAT, theta2, PL_FLOAT, disc2);
|
samer@0
|
407 }
|
samer@0
|
408
|
samer@0
|
409 /*
|
samer@0
|
410 foreign_t sum_lengths( term_t Lists, term_t Total)
|
samer@0
|
411 {
|
samer@0
|
412 double total=0;
|
samer@0
|
413 size_t len=0;
|
samer@0
|
414 term_t List=PL_new_term_ref();
|
samer@0
|
415 term_t Tail=PL_new_term_ref();
|
samer@0
|
416 int rc=1;
|
samer@0
|
417
|
samer@0
|
418 while (rc && PL_get_list(Lists,List,Lists)) {
|
samer@0
|
419 rc = PL_skip_list(List,Tail,&len);
|
samer@0
|
420 total += len;
|
samer@0
|
421 }
|
samer@0
|
422 return rc && PL_unify_integer(Total, total);
|
samer@0
|
423 }
|
samer@0
|
424 */
|