view crp.c @ 3:974d7be8eec4 tip

Update to pack-based dcg utilities
author samer
date Tue, 03 Oct 2017 11:52:23 +0100
parents b31415b4a196
children
line wrap: on
line source
/*
 */

#include <SWI-Prolog.h>
#include <math.h>
#include <float.h>
#include <stdio.h>

#include "rndutils.h"
#include "plutils.c"

// Module initialisation hook, defined below; it registers the foreign
// predicates and creates the atom/functor handles declared in this section.
install_t install();

// Foreign predicates. install() registers each one under its own name with
// an arity equal to the number of term_t parameters listed here.
foreign_t crp_prob( term_t alpha, term_t classes, term_t x, term_t pprob, term_t p);
foreign_t crp_sample( term_t alpha, term_t classes, term_t action, term_t rnd1, term_t rnd2);
foreign_t crp_sample_obs( term_t alpha, term_t classes, term_t x, term_t probx, term_t act, term_t rnd1, term_t rnd2);
foreign_t crp_sample_rm( term_t classes, term_t x, term_t class, term_t rnd1, term_t rnd2);
foreign_t sample_dp_teh( term_t ApSumKX, term_t B, term_t NX, term_t p1, term_t p2, term_t rnd1, term_t rnd2);
foreign_t sample_py_teh( term_t ThPrior, term_t DPrior, term_t CountsX, term_t p1, term_t p2, term_t rnd1, term_t rnd2);

// Handles for the atom `new` and the functors old/1, old/2, dp/1 and py/2.
// They are assigned in install() and must not be used before it has run.
static atom_t atom_new;
static functor_t functor_old1, functor_old2;
static functor_t functor_dp1, functor_py2;

/*
 * Module initialisation: registers every foreign predicate of this file
 * with SWI-Prolog (flags 0) and then creates the atom and functor handles
 * used to recognise model terms (dp/1, py/2) and to build result terms
 * (new, old/1, old/2).
 */
install_t install() {
	/* name, arity and implementation of each foreign predicate */
	struct {
		const char *name;
		int         arity;
		void       *impl;
	} preds[] = {
		{ "crp_prob",       5, (void *)crp_prob },
		{ "crp_sample",     5, (void *)crp_sample },
		{ "crp_sample_obs", 7, (void *)crp_sample_obs },
		{ "crp_sample_rm",  5, (void *)crp_sample_rm },
		{ "sample_dp_teh",  7, (void *)sample_dp_teh },
		{ "sample_py_teh",  7, (void *)sample_py_teh },
	};
	size_t k;

	for (k = 0; k < sizeof preds / sizeof preds[0]; k++) {
		PL_register_foreign(preds[k].name, preds[k].arity, preds[k].impl, 0);
	}

	functor_dp1  = PL_new_functor(PL_new_atom("dp"), 1);
	functor_py2  = PL_new_functor(PL_new_atom("py"), 2);
	functor_old1 = PL_new_functor(PL_new_atom("old"), 1);
	functor_old2 = PL_new_functor(PL_new_atom("old"), 2);
	atom_new     = PL_new_atom("new");
}


// Unify the Prolog term `state` with a BLOB of type `pblob` whose payload
// is the sizeof(RndState) bytes found at S. Returns the result of
// PL_unify_blob.
static int unify_state(term_t state, RndState *S, PL_blob_t *pblob) {
	const size_t nbytes = sizeof *S;

	return PL_unify_blob(state, S, nbytes, pblob);
}

// Extract a copy of the RndState held in a Prolog BLOB.
//
//   state   term expected to hold the random-state BLOB
//   S0      receives a copy of the stored RndState on success
//   ppblob  receives the BLOB type descriptor, so that the caller can
//           create the successor state with the same type
//
// Returns TRUE on success. Returns FALSE, leaving *S0 untouched, when the
// term does not hold a blob or when its payload is too small to contain an
// RndState. PL_get_blob also succeeds on ordinary atoms, whose text would
// otherwise be read past its end, hence the size check.
static int get_state(term_t state, RndState *S0, PL_blob_t **ppblob)
{
	size_t    len = 0;
	RndState *S = NULL;

	if (!PL_get_blob(state, (void **)&S, &len, ppblob)) return FALSE;
	if (S == NULL || len < sizeof(RndState)) return FALSE;

	*S0 = *S;
	return TRUE;
}


// -----------------------------------------------------
// Prolog versions of functions to implement 

// Writes the unnormalised weights for a dp/1 or py/2 model into dist:
// dist[0] is the weight of a new class, dist[1..len] come from counts.
int counts_dist( term_t gem, term_t counts, size_t len, double *dist);
// Reads the three arguments of Classes: the first as a long into *len,
// the second and third into the term references Counts and Vals.
int get_classes(term_t Classes, term_t Counts, term_t Vals, long *len);
// Divides each of the len elements of x by their sum, in place.
void stoch(double *x, size_t len);

/*
%% crp_prob( +GEM:gem_model, +Classes:classes(A), +X:A, +PProb:float, -Prob:float) is det.
%
%  Compute the probability Prob of observing X given a CRP
%  and a base probability of PProb.
crp_prob( Alpha, classes(Counts,Vals), X, PProb, P) :-
	counts_dist( Alpha, Counts, Counts1),
	stoch( Counts1, Probs, _),
	maplist( equal(X), Vals, Mask),
	maplist( mul, [PProb | Mask], Probs, PostProbs),
	sumlist( PostProbs, P).

*/

/*
 * crp_prob(+Alpha, +Classes, +X, +PProb, -Prob)
 *
 * Builds the weight vector for the model Alpha and the class counts in
 * Classes, normalises it, and unifies Prob with
 *   PProb * weight(new class) + sum of the weights of the classes whose
 *   value unifies with X.
 * Fails if any argument cannot be read or the weight buffer cannot be
 * allocated. Note that the comparison is done with PL_unify, so it may
 * bind variables in X or in the stored values.
 */
foreign_t crp_prob( term_t Alpha, term_t Classes, term_t X, term_t PProb, term_t Prob)
{
	term_t  counts_t = PL_new_term_ref();
	term_t  vals_t   = PL_new_term_ref();
	double  base_prob;
	double  total    = 0;
	double *weights  = NULL;
	long    nclasses = 0;
	int     ok;

	ok =   get_double(PProb, &base_prob)
	    && get_classes(Classes, counts_t, vals_t, &nclasses)
	    && alloc_array(nclasses+1, sizeof(double), (void **)&weights)
	    && counts_dist(Alpha, counts_t, nclasses, weights);

	if (ok) {
		term_t item = PL_new_term_ref();
		int    k    = 1;

		stoch(weights, nclasses+1);
		total = base_prob*weights[0];	/* contribution of a brand-new class */
		while (k<=nclasses && PL_get_list(vals_t, item, vals_t)) {
			if (PL_unify(item, X)) total += weights[k];
			k++;
		}
	}

	free(weights);
	return ok && PL_unify_float(Prob, total);
}

/*


%% crp_sample( +GEM:gem_model, +Classes:classes(A), -A:action(A))// is det.
%
%  Sample a new value from CRP, Action A is either new, which means
%  that the user should sample a new value from the base distribution,
%  or old(X,C), where X is an old value and C is the index of its class.
%  Operates in random state DCG.
crp_sample( Alpha, classes(Counts,Vals), Action, RS1, RS2) :-
	counts_dist(Alpha, Counts, Counts1),
	discrete(Counts1,Z,RS1,RS2),
	( Z>1 -> succ(C,Z), nth1(C,Vals,X), Action=old(X,C)
	; Action=new).

*/

/*
 * crp_sample(+Alpha, +Classes, -Action, +Rnd1, -Rnd2)
 *
 * Draws an index from the unnormalised weights produced by counts_dist.
 * Index 0 means "new class" and Action is unified with the atom new.
 * Any other index Z selects the Z-th value X of the value list and Action
 * is unified with old(X,Z); the predicate fails if the list has fewer
 * than Z elements. Rnd2 is unified with the random state after the draw.
 */
foreign_t crp_sample( term_t Alpha, term_t Classes, term_t Action, term_t Rnd1, term_t Rnd2)
{
	term_t     counts_t = PL_new_term_ref();
	term_t     vals_t   = PL_new_term_ref();
	PL_blob_t *blob_type;
	double    *weights  = NULL;
	RndState   state;
	long       nclasses = 0;
	int        ok;

	ok =   get_classes(Classes, counts_t, vals_t, &nclasses)
	    && alloc_array(nclasses+1, sizeof(double), (void **)&weights)
	    && counts_dist(Alpha, counts_t, nclasses, weights)
	    && get_state(Rnd1, &state, &blob_type);

	if (ok) {
		int pick = Discrete(&state, nclasses+1, weights, sum_array(weights, nclasses+1));

		if (pick==0) {
			ok = PL_unify_atom(Action, atom_new);
		} else {
			term_t X = PL_new_term_ref();
			int    steps;

			/* walk to the pick-th element; X ends up holding it */
			for (steps=0; steps<pick && PL_get_list(vals_t, X, vals_t); steps++)
				;
			ok =   (steps==pick)
			    && PL_unify_term(Action, PL_FUNCTOR, functor_old2, PL_TERM, X, PL_INT, pick);
		}
	}

	free(weights);
	return ok && unify_state(Rnd2, &state, blob_type);
}

/*

%% crp_sample_obs( +GEM:gem_model, +Classes:classes(A), +X:A, +PProb:float, -A:action)// is det.
%
%  Sample class appropriate for observation of value X. PProb is the
%  base probability of X from the base distribution. Action A is new
%  or old(Class).
%  Operates in random state DCG.
crp_sample_obs( Alpha, classes(Counts,Vals), X, ProbX, A, RS1, RS2) :-
	counts_dist( Alpha, Counts, [CNew|Counts1]),	
	PNew is CNew*ProbX,
	maplist( post_count(X),Vals,Counts1,Counts2),
	discrete( [PNew|Counts2], Z, RS1, RS2),
	( Z=1 -> A=new; succ(C,Z), A=old(C)).

*/

/*
 * crp_sample_obs(+Alpha, +Classes, +X, +ProbX, -Act, +Rnd1, -Rnd2)
 *
 * Samples the class to which an observation of X is assigned. The weight
 * of a new class is scaled by ProbX and the weight of every existing
 * class whose value does not unify with X is set to zero. Act is unified
 * with new when index 0 is drawn, otherwise with old(Z) where Z is the
 * 1-based class index. Rnd2 is unified with the random state after the
 * draw.
 *
 * Fails if an argument cannot be read, if the weight buffer cannot be
 * allocated, or if the drawn index does not name one of the len classes.
 * The last check mirrors the range checks already made in crp_sample_rm
 * (z<len) and crp_sample (list walk), so that an out-of-range draw is never
 * reported as old(len+1).
 * NOTE(review): this assumes Discrete can return an index past the last
 * weight (e.g. when all weights are zero) -- confirm against rndutils.
 */
foreign_t crp_sample_obs( term_t Alpha, term_t Classes, term_t X, term_t Probx, term_t Act, term_t Rnd1, term_t Rnd2)
{
	term_t 		Counts=PL_new_term_ref();
	term_t 		Vals=PL_new_term_ref();
	PL_blob_t 	*blob;
	double	probx=0;
	double 	*dist=NULL;
	long 		len=0;
	RndState rs;

	int rc =	get_double(Probx,&probx)
		&& get_classes(Classes, Counts, Vals, &len)
		&& alloc_array(len+1, sizeof(double), (void **)&dist)
		&& counts_dist(Alpha, Counts, len, dist)
		&& get_state(Rnd1,&rs,&blob);

	if (rc) {
		term_t 	Val=PL_new_term_ref();
		long		i;
		int		z;

		dist[0] *= probx;
		/* only classes whose value matches X keep their weight */
		for (i=1; i<=len && PL_get_list(Vals,Val,Vals); i++) {
			if (!PL_unify(Val,X)) dist[i]=0;
		}

		z=Discrete( &rs, len+1, dist, sum_array(dist,len+1));
		if (z==0) { rc = PL_unify_atom(Act,atom_new); }
		else {
			rc = (z<=len)
				&& PL_unify_term(Act, PL_FUNCTOR, functor_old1, PL_INT, z);
		}
	}
	free(dist);
	return rc && unify_state(Rnd2,&rs,blob);
}

/*
%% crp_sample_rm( +Classes:classes(A), +X:A, -C:natural)// is det.
%
%  Sample appropriate class from which to remove value X.
%  Operates in random state DCG.
crp_sample_rm( classes(Counts,Vals), X, Class, RS1, RS2) :-
	maplist(post_count(X),Vals,Counts,Counts1),
	discrete( Counts1, Class, RS1, RS2).

*/

/*
 * crp_sample_rm(+Classes, +X, -Class, +Rnd1, -Rnd2)
 *
 * Samples the class from which an instance of X is to be removed. The
 * class counts are used as weights, with the weight of every class whose
 * value does not unify with X set to zero. Class is unified with the
 * 1-based index of the drawn class; the predicate fails when the draw
 * does not land on one of the classes. Rnd2 is unified with the random
 * state after the draw.
 */
foreign_t crp_sample_rm( term_t Classes, term_t X, term_t Class, term_t Rnd1, term_t Rnd2)
{
	term_t     counts_t = PL_new_term_ref();
	term_t     vals_t   = PL_new_term_ref();
	PL_blob_t *blob_type;
	double    *weights  = NULL;
	long       nclasses = 0;
	RndState   state;
	int        ok;

	ok =   get_classes(Classes, counts_t, vals_t, &nclasses)
	    && alloc_array(nclasses, sizeof(double), (void **)&weights)
	    && get_list_doubles(counts_t, weights, nclasses)
	    && get_state(Rnd1, &state, &blob_type);

	if (ok) {
		term_t item = PL_new_term_ref();
		int    k    = 0;
		int    pick;

		while (k<nclasses && PL_get_list(vals_t, item, vals_t)) {
			if (!PL_unify(item, X)) weights[k] = 0;
			k++;
		}

		pick = Discrete(&state, nclasses, weights, sum_array(weights, nclasses));
		ok   = (pick<nclasses) && PL_unify_integer(Class, pick+1);
	}

	free(weights);
	return ok && unify_state(Rnd2, &state, blob_type);
}

/*
post_count(X,Val,Count,PC) :- X=Val -> PC=Count; PC=0.

% -----------------------------------------------------------
% Dirichlet process and Pitman-Yor process
% pseudo-counts models.

counts_dist(_,[],0,[1]) :- !.
counts_dist(dp(Alpha),Counts,_,[Alpha|Counts]) :- !.
counts_dist(py(Alpha,Discount),Counts,K,[CNew|Counts1]) :- !,
	CNew is Alpha+Discount*K,
	maplist(sub(Discount),Counts,Counts1).

*/

/*
 * Reads argument n (1-based) of the compound term Term as a double into
 * *px. Returns 1 on success, 0 if the argument cannot be fetched or is
 * not accepted by PL_get_float.
 */
int get_float_arg(int n,term_t Term, double *px)
{
	term_t arg = PL_new_term_ref();

	if (!PL_get_arg(n, Term, arg)) return 0;
	return PL_get_float(arg, px) != 0;
}

/*
 * Fills dist[0..len] with the unnormalised class weights of a model.
 *
 *   gem     dp(Alpha) or py(Theta,Discount)
 *   counts  Prolog list holding (at least) len numeric class counts
 *   len     number of existing classes
 *   dist    caller-allocated buffer with room for len+1 doubles
 *
 * dist[0] is the weight of a new class: Alpha for dp/1 and
 * Theta + Discount*len for py/2. dist[i], i>=1, is the i-th count, minus
 * Discount for py/2. With len==0 the model is not inspected and dist[0]
 * is set to 1.
 *
 * Returns TRUE on success. Returns FALSE if gem is neither dp/1 nor py/2,
 * if a parameter or count is not numeric, or if counts holds fewer than
 * len elements. The last case used to return success while leaving the
 * tail of dist unwritten.
 */
int counts_dist( term_t gem, term_t counts, size_t len, double *dist)
{
	term_t head;
	double disc = 0;	/* subtracted from every count; stays 0 for dp/1 */
	size_t i;

	if (len==0) { dist[0]=1; return TRUE; }

	if (PL_is_functor(gem, functor_dp1)) {
		double alpha;

		if (!get_float_arg(1,gem,&alpha)) return FALSE;
		dist[0] = alpha;
	} else if (PL_is_functor(gem, functor_py2)) {
		double theta;

		if (!(get_float_arg(1,gem,&theta) && get_float_arg(2,gem,&disc))) return FALSE;
		dist[0] = theta + disc*len;
	} else return FALSE;

	head = PL_new_term_ref();
	for (i=1; i<=len && PL_get_list(counts,head,counts); i++) {
		double c;

		if (!PL_get_float(head,&c)) return FALSE;
		dist[i] = c-disc;
	}
	return i>len;	/* FALSE when the list ran out before len counts were read */
}

/*
 * Decomposes a classes term with three arguments.
 *
 *   Classes  compound term whose arguments are, in order: the number of
 *            classes, the list of counts and the list of values
 *   Counts   term reference that receives argument 2
 *   Vals     term reference that receives argument 3
 *   len      receives argument 1 as a long
 *
 * Returns non-zero on success. Fails if Classes has fewer than three
 * arguments, if the first argument is not an integer that fits a long, or
 * if it is negative: callers size their buffers from *len and pass it on
 * as a size_t, so a negative value must never get through.
 */
int get_classes(term_t Classes, term_t Counts, term_t Vals, long *len)
{
	term_t K=PL_new_term_ref();

	return 	PL_get_arg(1,Classes,K)
			&& PL_get_arg(2,Classes,Counts)
			&& PL_get_arg(3,Classes,Vals)
			&& PL_get_long(K,len)
			&& *len >= 0;
}


/*
 * Normalises x[0..len-1] in place so that the elements sum to one, by
 * dividing each element by the total of all elements.
 *
 * No check is made on the total: if it is zero the division follows the
 * usual floating-point rules and the results are not finite. With len==0
 * nothing is read or written.
 */
void stoch(double *x, size_t len)
{
	size_t i;	/* same type as len, so the index cannot overflow first */
	double total = 0;

	for (i = 0; i < len; i++) total += x[i];
	for (i = 0; i < len; i++) x[i] /= total;
}

/*
sample_dp_teh( ApSumKX, B, NX, dp(Alpha1), dp(Alpha2)) -->
	{ Alpha1_1 is Alpha1+1  },
	seqmap(beta(Alpha1_1),NX,WX),
	seqmap(bernoulli(Alpha1),NX,SX),
	{	maplist(log,WX,LogWX),
		sumlist(SX,SumSX), 
		sumlist(LogWX,SumLogWX), 
		A1 is ApSumKX-SumSX, B1 is B-SumLogWX
	},
	gamma(A1,B1,Alpha2).

%	run_left( seqmap(accum_log_beta(Alpha1_1),NX), 0, SumLogWX), 
%	run_left( seqmap(accum_bernoulli(Alpha1),NX), 0, SumSX), 
%accum_log_beta(A,B) --> \> beta(A,B,X), { LogX is log(X) }, \< add(LogX).
%accum_bernoulli(A,B) --> \> bernoulli(A,B,X), \< add(X).  

 */

/*
 * Draws one value from Uniform(rs) and returns 1 when (a+b) times that
 * value is strictly less than b, otherwise 0.
 */
int Bernoulli(RndState *rs,double a,double b) {
	return (a+b)*Uniform(rs) < b;
}


/*
 * sample_dp_teh(+ApSumKX, +B, +NX, +P1, -P2, +Rnd1, -Rnd2)
 *
 * Resamples the parameter of a dp/1 model; see the Prolog reference in
 * the comment preceding Bernoulli. For every integer N in the list NX one
 * Beta(Alpha1+1, N) value and one Bernoulli(Alpha1, N) value are drawn,
 * and the logs of the former and the latter are summed. P2 is then unified
 * with dp(Alpha2), where
 *   Alpha2 = Gamma(ApSumKX - SumS) / (B - SumLogW),
 * and Rnd2 with the random state after all draws.
 *
 * Fails before drawing anything if ApSumKX, B, the first argument of P1
 * or the random state cannot be read, and as soon as an element of NX is
 * not an integer. Previously the sampling calls still ran in those cases,
 * on a random state and parameters that had never been initialised.
 */
foreign_t sample_dp_teh( term_t ApSumKX, term_t B, term_t NX, term_t p1, term_t p2, term_t rnd1, term_t rnd2)
{
	term_t	N=PL_new_term_ref();
	PL_blob_t *pblob;
	double	apsumkx, b, alphap1;
	double	alpha1=0, alpha2;
	double 	sum_log_wx=0, sum_sx=0;
	RndState rs;

	int rc = get_double(ApSumKX, &apsumkx)
			&& get_double(B, &b)
			&& get_float_arg(1,p1,&alpha1)
			&& get_state(rnd1,&rs,&pblob);

	if (!rc) return FALSE;

	alphap1 = alpha1+1;
	while (PL_get_list(NX,N,NX)) {
		long n=0;

		if (!get_long(N,&n)) return FALSE;
		sum_log_wx += log(Beta(&rs,alphap1,n));
		sum_sx     += Bernoulli(&rs,alpha1,n);
	}
	alpha2 = Gamma(&rs, apsumkx-sum_sx)/(b-sum_log_wx);

	return PL_unify_term(p2, PL_FUNCTOR, functor_dp1, PL_FLOAT, alpha2)
			&& unify_state(rnd2,&rs,pblob);
}

/*
 * sample_py_teh(+ThPrior, +DPrior, +CountsX, +P1, -P2, +Rnd1, -Rnd2)
 *
 * Resamples both parameters of a py/2 model.
 *
 *   ThPrior  compound whose two arguments are read as theta_a, theta_b
 *   DPrior   compound whose two arguments are read as disc_a, disc_b
 *   CountsX  list of lists of integer counts
 *   P1       current parameters; arguments 1 and 2 are theta and discount
 *   P2       unified with py(Theta2,Disc2)
 *   Rnd1     random state before, Rnd2 random state after sampling
 *
 * For each inner list, with k the position of a count c and n the sum of
 * the counts in that list:
 *   - for k>0 one Bernoulli(disc*k, theta) draw increments sum_sx when it
 *     is 1 and sum_nsx otherwise;
 *   - for i = 0..c-2 the Bernoulli(i, 1-disc) draws are added to sum_zx;
 *   - if n>1, log(Beta(theta+1, n-1)) is added to sum_log_wx.
 * Finally
 *   Theta2 = Gamma(theta_a + sum_sx) / (theta_b - sum_log_wx)
 *   Disc2  = Beta(disc_a + sum_nsx, disc_b + sum_zx).
 *
 * Fails before drawing anything if a parameter or the random state cannot
 * be read, and as soon as a count is not an integer. Previously the
 * sampling calls still ran in those cases, on a random state and
 * parameters that had never been initialised. The per-list total is kept
 * in a long so that it has the same width as the counts being summed.
 */
foreign_t sample_py_teh( term_t ThPrior, term_t DPrior, term_t CountsX, term_t p1, term_t p2, term_t rnd1, term_t rnd2)
{
	PL_blob_t *pblob;
	term_t	Counts = PL_new_term_ref();
	term_t	Count = PL_new_term_ref();
	double	theta_a,  disc_a;
	double	theta_b,  disc_b;
	double	theta1,   disc1;
	double	theta2,   disc2;
	double   theta1_1, disc1_1;
	double 	sum_log_wx, sum_sx, sum_nsx, sum_zx;
	RndState rs;

	int rc = get_float_arg(1,ThPrior,&theta_a)
			&& get_float_arg(2,ThPrior,&theta_b)
			&& get_float_arg(1,DPrior,&disc_a)
			&& get_float_arg(2,DPrior,&disc_b)
			&& get_float_arg(1,p1,&theta1)
			&& get_float_arg(2,p1,&disc1)
			&& get_state(rnd1,&rs,&pblob);

	if (!rc) return FALSE;

	theta1_1 = theta1+1;	/* theta + 1 */
	disc1_1	= 1-disc1;	/* 1 - discount */
	sum_log_wx = sum_sx = sum_nsx = sum_zx = 0;
	while (PL_get_list(CountsX,Counts,CountsX)) {
		long n=0, c=0, i;
		int  k;

		for(k=0; PL_get_list(Counts,Count,Counts); k++, n+=c) {
			if (!get_long(Count,&c)) return FALSE;
			if (k>0) { if (Bernoulli(&rs, disc1*k, theta1)) sum_sx++; else sum_nsx++; }
			for (i=0; i<c-1; i++) sum_zx += Bernoulli(&rs, i, disc1_1);
		}
		if (n>1) sum_log_wx += log(Beta(&rs, theta1_1, n-1));
	}

	theta2 = Gamma(&rs, theta_a + sum_sx)/(theta_b-sum_log_wx);
	disc2  = Beta(&rs, disc_a + sum_nsx, disc_b + sum_zx);
	return unify_state(rnd2,&rs,pblob)
			&& PL_unify_term(p2, PL_FUNCTOR, functor_py2, PL_FLOAT, theta2, PL_FLOAT, disc2);
}

/*
foreign_t sum_lengths( term_t Lists, term_t Total)
{
	double total=0;
	size_t len=0;
	term_t List=PL_new_term_ref();
	term_t Tail=PL_new_term_ref();
	int rc=1;

	while (rc && PL_get_list(Lists,List,Lists)) {
		rc = PL_skip_list(List,Tail,&len); 		
		total += len;
	}
	return rc && PL_unify_integer(Total, total);
}
*/