annotate toolboxes/FullBNT-1.0.7/bnt/inference/static/@gibbs_sampling_inf_engine/private/compute_posterior.c @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
rev   line source
wolffd@0 1 #include "mex.h"
wolffd@0 2
wolffd@0 3 /* Helper function that extracts a one-dimensional slice from a cpt */
wolffd@0 4 /*
wolffd@0 5 void multiplySlice(mxArray *bnet, mxArray *state, int i, int nsi, int j,
wolffd@0 6 mxArray *strides, mxArray *fam, mxArray *cpts,
wolffd@0 7 double *y)
wolffd@0 8 */
wolffd@0 9 void multiplySlice(const mxArray *bnet, const mxArray *state, int i, int nsi, int j,
wolffd@0 10 const mxArray *strides, const mxArray *fam, const mxArray *cpts,
wolffd@0 11 double *y)
wolffd@0 12 {
wolffd@0 13 mxArray *ec, *cpt, *family;
wolffd@0 14 double *ecElts, *cptElts, *famElts, *strideElts, *ev;
wolffd@0 15 int c1, k, famSize, startInd, strideStride, pos, stride;
wolffd@0 16
wolffd@0 17 strideStride = mxGetM(strides);
wolffd@0 18 strideElts = mxGetPr(strides);
wolffd@0 19
wolffd@0 20 ev = mxGetPr(state);
wolffd@0 21
wolffd@0 22 /* Get the CPT */
wolffd@0 23 ec = mxGetField (bnet, 0, "equiv_class");
wolffd@0 24 ecElts = mxGetPr(ec);
wolffd@0 25 k = (int) ecElts[j-1];
wolffd@0 26 cpt = mxGetCell (cpts, k-1);
wolffd@0 27 cptElts = mxGetPr (cpt);
wolffd@0 28
wolffd@0 29 /* Get the family vector for this cpt */
wolffd@0 30 family = mxGetCell (fam, j-1);
wolffd@0 31 famSize = mxGetNumberOfElements (family);
wolffd@0 32 famElts = mxGetPr (family);
wolffd@0 33
wolffd@0 34 /* Figure out starting position and stride */
wolffd@0 35 startInd = 0;
wolffd@0 36 for (c1 = 0, pos = k-1; c1 < famSize; c1++, pos +=strideStride) {
wolffd@0 37 if (famElts[c1] != i) {
wolffd@0 38 startInd += strideElts[pos]*(ev[(int)famElts[c1]-1]-1);
wolffd@0 39 }
wolffd@0 40 else {
wolffd@0 41 stride = strideElts[pos];
wolffd@0 42 }
wolffd@0 43 }
wolffd@0 44
wolffd@0 45 for (c1 = 0, pos = startInd; c1 < nsi; c1++, pos+=stride) {
wolffd@0 46 y[c1] *= cptElts[pos];
wolffd@0 47 }
wolffd@0 48 }
wolffd@0 49
wolffd@0 50
wolffd@0 51 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray
wolffd@0 52 *prhs[])
wolffd@0 53 {
wolffd@0 54 double *pi, *nsElts, *y, *childrenElts;
wolffd@0 55 mxArray *ns, *children;
wolffd@0 56 double sum;
wolffd@0 57 int i, nsi, c1, numChildren;
wolffd@0 58
wolffd@0 59 pi = mxGetPr(prhs[2]);
wolffd@0 60 i = (int) pi[0];
wolffd@0 61
wolffd@0 62 ns = mxGetField(prhs[0], 0, "node_sizes");
wolffd@0 63 nsElts = mxGetPr(ns);
wolffd@0 64 nsi = (int) nsElts[i-1];
wolffd@0 65
wolffd@0 66 /* Initialize the posterior */
wolffd@0 67 plhs[0] = mxCreateDoubleMatrix (1, nsi, mxREAL);
wolffd@0 68 y = mxGetPr(plhs[0]);
wolffd@0 69 for (c1 = 0; c1 < nsi; c1++) {
wolffd@0 70 y[c1] = 1;
wolffd@0 71 }
wolffd@0 72
wolffd@0 73 /* Multiply in the cpt of the node i */
wolffd@0 74 multiplySlice(prhs[0], prhs[1], i, nsi, i, prhs[3], prhs[4],
wolffd@0 75 prhs[6], y);
wolffd@0 76
wolffd@0 77
wolffd@0 78 /* Multiply in cpts of children of i */
wolffd@0 79 children = mxGetCell (prhs[5], i-1);
wolffd@0 80 numChildren = mxGetNumberOfElements (children);
wolffd@0 81 childrenElts = mxGetPr (children);
wolffd@0 82
wolffd@0 83 for (c1 = 0; c1 < numChildren; c1++) {
wolffd@0 84 int j;
wolffd@0 85 j = (int) childrenElts[c1];
wolffd@0 86 multiplySlice (prhs[0], prhs[1], i, nsi, j, prhs[3], prhs[4],
wolffd@0 87 prhs[6], y);
wolffd@0 88 }
wolffd@0 89
wolffd@0 90 sum = 0;
wolffd@0 91 /* normalize! */
wolffd@0 92 for (c1 = 0; c1 < nsi; c1++) {
wolffd@0 93 sum += y[c1];
wolffd@0 94 }
wolffd@0 95
wolffd@0 96 for (c1 = 0; c1 < nsi; c1++) {
wolffd@0 97 y[c1] /= sum;
wolffd@0 98 }
wolffd@0 99 }
wolffd@0 100
wolffd@0 101
wolffd@0 102
wolffd@0 103
wolffd@0 104
wolffd@0 105
wolffd@0 106
wolffd@0 107