comparison toolboxes/FullBNT-1.0.7/bnt/inference/static/@gibbs_sampling_inf_engine/private/compute_posterior.c @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
1 #include "mex.h"
2
3 /* Helper function that extracts a one-dimensional slice from a cpt */
4 /*
5 void multiplySlice(mxArray *bnet, mxArray *state, int i, int nsi, int j,
6 mxArray *strides, mxArray *fam, mxArray *cpts,
7 double *y)
8 */
9 void multiplySlice(const mxArray *bnet, const mxArray *state, int i, int nsi, int j,
10 const mxArray *strides, const mxArray *fam, const mxArray *cpts,
11 double *y)
12 {
13 mxArray *ec, *cpt, *family;
14 double *ecElts, *cptElts, *famElts, *strideElts, *ev;
15 int c1, k, famSize, startInd, strideStride, pos, stride;
16
17 strideStride = mxGetM(strides);
18 strideElts = mxGetPr(strides);
19
20 ev = mxGetPr(state);
21
22 /* Get the CPT */
23 ec = mxGetField (bnet, 0, "equiv_class");
24 ecElts = mxGetPr(ec);
25 k = (int) ecElts[j-1];
26 cpt = mxGetCell (cpts, k-1);
27 cptElts = mxGetPr (cpt);
28
29 /* Get the family vector for this cpt */
30 family = mxGetCell (fam, j-1);
31 famSize = mxGetNumberOfElements (family);
32 famElts = mxGetPr (family);
33
34 /* Figure out starting position and stride */
35 startInd = 0;
36 for (c1 = 0, pos = k-1; c1 < famSize; c1++, pos +=strideStride) {
37 if (famElts[c1] != i) {
38 startInd += strideElts[pos]*(ev[(int)famElts[c1]-1]-1);
39 }
40 else {
41 stride = strideElts[pos];
42 }
43 }
44
45 for (c1 = 0, pos = startInd; c1 < nsi; c1++, pos+=stride) {
46 y[c1] *= cptElts[pos];
47 }
48 }
49
50
51 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray
52 *prhs[])
53 {
54 double *pi, *nsElts, *y, *childrenElts;
55 mxArray *ns, *children;
56 double sum;
57 int i, nsi, c1, numChildren;
58
59 pi = mxGetPr(prhs[2]);
60 i = (int) pi[0];
61
62 ns = mxGetField(prhs[0], 0, "node_sizes");
63 nsElts = mxGetPr(ns);
64 nsi = (int) nsElts[i-1];
65
66 /* Initialize the posterior */
67 plhs[0] = mxCreateDoubleMatrix (1, nsi, mxREAL);
68 y = mxGetPr(plhs[0]);
69 for (c1 = 0; c1 < nsi; c1++) {
70 y[c1] = 1;
71 }
72
73 /* Multiply in the cpt of the node i */
74 multiplySlice(prhs[0], prhs[1], i, nsi, i, prhs[3], prhs[4],
75 prhs[6], y);
76
77
78 /* Multiply in cpts of children of i */
79 children = mxGetCell (prhs[5], i-1);
80 numChildren = mxGetNumberOfElements (children);
81 childrenElts = mxGetPr (children);
82
83 for (c1 = 0; c1 < numChildren; c1++) {
84 int j;
85 j = (int) childrenElts[c1];
86 multiplySlice (prhs[0], prhs[1], i, nsi, j, prhs[3], prhs[4],
87 prhs[6], y);
88 }
89
90 sum = 0;
91 /* normalize! */
92 for (c1 = 0; c1 < nsi; c1++) {
93 sum += y[c1];
94 }
95
96 for (c1 = 0; c1 < nsi; c1++) {
97 y[c1] /= sum;
98 }
99 }
100
101
102
103
104
105
106
107