To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.
root / _FullBNT / BNT / CPDs / @discrete_CPD / convert_to_sparse_table.c @ 8:b5b38998ef3b
History | View | Annotate | Download (3.52 KB)
| 1 |
/* convert_to_sparse_table.c convert a sparse discrete CPD with evidence into sparse table */
|
|---|---|
| 2 |
/* convert_to_pot.m located in ../CPDs/discrete_CPD call it */
|
| 3 |
/* 3 input */
|
| 4 |
/* CPD prhs[0] with 1D sparse CPT */
|
| 5 |
/* domain prhs[1] */
|
| 6 |
/* evidence prhs[2] */
|
| 7 |
/* 1 output */
|
| 8 |
/* T plhs[0] sparse table */
|
| 9 |
|
| 10 |
#include <math.h> |
| 11 |
#include "mex.h" |
| 12 |
|
| 13 |
void ind_subv(int index, const int *cumprod, const int n, int *bsubv){ |
| 14 |
int i;
|
| 15 |
|
| 16 |
for (i = n-1; i >= 0; i--) { |
| 17 |
bsubv[i] = ((int)floor(index / cumprod[i]));
|
| 18 |
index = index % cumprod[i]; |
| 19 |
} |
| 20 |
} |
| 21 |
|
| 22 |
int subv_ind(const int n, const int *cumprod, const int *subv){ |
| 23 |
int i, index=0; |
| 24 |
|
| 25 |
for(i=0; i<n; i++){ |
| 26 |
index += subv[i] * cumprod[i]; |
| 27 |
} |
| 28 |
return index;
|
| 29 |
} |
| 30 |
|
| 31 |
void reset_nzmax(mxArray *spArray, const int old_nzmax, const int new_nzmax){ |
| 32 |
double *ptr;
|
| 33 |
void *newptr;
|
| 34 |
int *ir, *jc;
|
| 35 |
int nbytes;
|
| 36 |
|
| 37 |
if(new_nzmax == old_nzmax) return; |
| 38 |
nbytes = new_nzmax * sizeof(*ptr);
|
| 39 |
ptr = mxGetPr(spArray); |
| 40 |
newptr = mxRealloc(ptr, nbytes); |
| 41 |
mxSetPr(spArray, newptr); |
| 42 |
nbytes = new_nzmax * sizeof(*ir);
|
| 43 |
ir = mxGetIr(spArray); |
| 44 |
newptr = mxRealloc(ir, nbytes); |
| 45 |
mxSetIr(spArray, newptr); |
| 46 |
jc = mxGetJc(spArray); |
| 47 |
jc[0] = 0; |
| 48 |
jc[1] = new_nzmax;
|
| 49 |
mxSetNzmax(spArray, new_nzmax); |
| 50 |
} |
| 51 |
|
| 52 |
|
| 53 |
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){ |
| 54 |
int i, j, NS, NZB, count, bdim, match, domain, bindex, sindex, nzCounts=0; |
| 55 |
int *observed, *bsubv, *ssubv, *bir, *sir, *bjc, *sjc, *mask, *ssize, *bcumprod, *scumprod;
|
| 56 |
double *pDomain, *pSize, *bpr, *spr;
|
| 57 |
mxArray *pTemp; |
| 58 |
|
| 59 |
pTemp = mxGetField(prhs[0], 0, "CPT"); |
| 60 |
bpr = mxGetPr(pTemp); |
| 61 |
bir = mxGetIr(pTemp); |
| 62 |
bjc = mxGetJc(pTemp); |
| 63 |
NZB = bjc[1];
|
| 64 |
pTemp = mxGetField(prhs[0], 0, "sizes"); |
| 65 |
pSize = mxGetPr(pTemp); |
| 66 |
|
| 67 |
pDomain = mxGetPr(prhs[1]);
|
| 68 |
bdim = mxGetNumberOfElements(prhs[1]);
|
| 69 |
|
| 70 |
mask = malloc(bdim * sizeof(int)); |
| 71 |
ssize = malloc(bdim * sizeof(int)); |
| 72 |
observed = malloc(bdim * sizeof(int)); |
| 73 |
|
| 74 |
for(i=0; i<bdim; i++){ |
| 75 |
ssize[i] = (int)pSize[i];
|
| 76 |
} |
| 77 |
|
| 78 |
count = 0;
|
| 79 |
for(i=0; i<bdim; i++){ |
| 80 |
domain = (int)pDomain[i] - 1; |
| 81 |
pTemp = mxGetCell(prhs[2], domain);
|
| 82 |
if(pTemp){
|
| 83 |
mask[count] = i; |
| 84 |
ssize[i] = 1;
|
| 85 |
observed[count] = (int)mxGetScalar(pTemp) - 1; |
| 86 |
count++; |
| 87 |
} |
| 88 |
} |
| 89 |
|
| 90 |
if(count == 0){ |
| 91 |
pTemp = mxGetField(prhs[0], 0, "CPT"); |
| 92 |
plhs[0] = mxDuplicateArray(pTemp);
|
| 93 |
free(mask); |
| 94 |
free(ssize); |
| 95 |
free(observed); |
| 96 |
return;
|
| 97 |
} |
| 98 |
|
| 99 |
bsubv = malloc(bdim * sizeof(int)); |
| 100 |
ssubv = malloc(count * sizeof(int)); |
| 101 |
bcumprod = malloc(bdim * sizeof(int)); |
| 102 |
scumprod = malloc(bdim * sizeof(int)); |
| 103 |
|
| 104 |
NS = 1;
|
| 105 |
for(i=0; i<bdim; i++){ |
| 106 |
NS *= ssize[i]; |
| 107 |
} |
| 108 |
|
| 109 |
plhs[0] = mxCreateSparse(NS, 1, NS, mxREAL); |
| 110 |
spr = mxGetPr(plhs[0]);
|
| 111 |
sir = mxGetIr(plhs[0]);
|
| 112 |
sjc = mxGetJc(plhs[0]);
|
| 113 |
sjc[0] = 0; |
| 114 |
sjc[1] = NS;
|
| 115 |
|
| 116 |
bcumprod[0] = 1; |
| 117 |
scumprod[0] = 1; |
| 118 |
for(i=0; i<bdim-1; i++){ |
| 119 |
bcumprod[i+1] = bcumprod[i] * (int)pSize[i]; |
| 120 |
scumprod[i+1] = scumprod[i] * ssize[i];
|
| 121 |
} |
| 122 |
|
| 123 |
nzCounts = 0;
|
| 124 |
for(i=0; i<NZB; i++){ |
| 125 |
bindex = bir[i]; |
| 126 |
ind_subv(bindex, bcumprod, bdim, bsubv); |
| 127 |
for(j=0; j<count; j++){ |
| 128 |
ssubv[j] = bsubv[mask[j]]; |
| 129 |
} |
| 130 |
match = 1;
|
| 131 |
for(j=0; j<count; j++){ |
| 132 |
if((ssubv[j]) != observed[j]){
|
| 133 |
match = 0;
|
| 134 |
break;
|
| 135 |
} |
| 136 |
} |
| 137 |
if(match){
|
| 138 |
spr[nzCounts] = bpr[i]; |
| 139 |
sindex = subv_ind(bdim, scumprod, bsubv); |
| 140 |
sir[nzCounts] = sindex; |
| 141 |
nzCounts++; |
| 142 |
} |
| 143 |
} |
| 144 |
|
| 145 |
reset_nzmax(plhs[0], NS, nzCounts);
|
| 146 |
free(mask); |
| 147 |
free(ssize); |
| 148 |
free(observed); |
| 149 |
free(bsubv); |
| 150 |
free(ssubv); |
| 151 |
free(bcumprod); |
| 152 |
free(scumprod); |
| 153 |
} |
| 154 |
|