wolffd@0
|
/* convert_to_sparse_table.c: convert a sparse discrete CPD with evidence into a sparse table. */
/* Called by convert_to_pot.m, located in ../CPDs/discrete_CPD. */
/* 3 inputs: */
/* CPD      prhs[0], whose CPT is a 1D sparse array */
/* domain   prhs[1] */
/* evidence prhs[2] */
/* 1 output: */
/* T        plhs[0], the sparse table */
|
wolffd@0
|
9
|
wolffd@0
|
10 #include <math.h>
|
wolffd@0
|
11 #include "mex.h"
|
wolffd@0
|
12
|
wolffd@0
|
/*
 * ind_subv: split a 0-based linear index into 0-based per-dimension
 * subscripts.
 *
 * index    - linear index to decompose.
 * cumprod  - stride of each dimension; callers build it as cumprod[0] == 1,
 *            cumprod[i+1] == cumprod[i] * size[i].
 * n        - number of dimensions (length of cumprod and bsubv).
 * bsubv    - output: receives the n subscripts.
 */
void ind_subv(int index, const int *cumprod, const int n, int *bsubv){
    int i;

    /* Peel dimensions off from the largest stride down. Integer division
     * already truncates, so no floating-point floor() round-trip is needed. */
    for(i = n - 1; i >= 0; i--){
        bsubv[i] = index / cumprod[i];
        index %= cumprod[i];
    }
}
|
wolffd@0
|
21
|
wolffd@0
|
/*
 * subv_ind: inverse of ind_subv. Fold 0-based subscripts back into a
 * 0-based linear index.
 *
 * n        - number of dimensions.
 * cumprod  - stride of each dimension.
 * subv     - the n subscripts.
 * Returns the sum over every dimension k of subv[k] * cumprod[k]
 * (0 when n <= 0).
 */
int subv_ind(const int n, const int *cumprod, const int *subv){
    int acc = 0;
    int k = 0;

    while(k < n){
        acc += subv[k] * cumprod[k];
        ++k;
    }
    return acc;
}
|
wolffd@0
|
30
|
wolffd@0
|
31 void reset_nzmax(mxArray *spArray, const int old_nzmax, const int new_nzmax){
|
wolffd@0
|
32 double *ptr;
|
wolffd@0
|
33 void *newptr;
|
wolffd@0
|
34 int *ir, *jc;
|
wolffd@0
|
35 int nbytes;
|
wolffd@0
|
36
|
wolffd@0
|
37 if(new_nzmax == old_nzmax) return;
|
wolffd@0
|
38 nbytes = new_nzmax * sizeof(*ptr);
|
wolffd@0
|
39 ptr = mxGetPr(spArray);
|
wolffd@0
|
40 newptr = mxRealloc(ptr, nbytes);
|
wolffd@0
|
41 mxSetPr(spArray, newptr);
|
wolffd@0
|
42 nbytes = new_nzmax * sizeof(*ir);
|
wolffd@0
|
43 ir = mxGetIr(spArray);
|
wolffd@0
|
44 newptr = mxRealloc(ir, nbytes);
|
wolffd@0
|
45 mxSetIr(spArray, newptr);
|
wolffd@0
|
46 jc = mxGetJc(spArray);
|
wolffd@0
|
47 jc[0] = 0;
|
wolffd@0
|
48 jc[1] = new_nzmax;
|
wolffd@0
|
49 mxSetNzmax(spArray, new_nzmax);
|
wolffd@0
|
50 }
|
wolffd@0
|
51
|
wolffd@0
|
52
|
wolffd@0
|
53 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){
|
wolffd@0
|
54 int i, j, NS, NZB, count, bdim, match, domain, bindex, sindex, nzCounts=0;
|
wolffd@0
|
55 int *observed, *bsubv, *ssubv, *bir, *sir, *bjc, *sjc, *mask, *ssize, *bcumprod, *scumprod;
|
wolffd@0
|
56 double *pDomain, *pSize, *bpr, *spr;
|
wolffd@0
|
57 mxArray *pTemp;
|
wolffd@0
|
58
|
wolffd@0
|
59 pTemp = mxGetField(prhs[0], 0, "CPT");
|
wolffd@0
|
60 bpr = mxGetPr(pTemp);
|
wolffd@0
|
61 bir = mxGetIr(pTemp);
|
wolffd@0
|
62 bjc = mxGetJc(pTemp);
|
wolffd@0
|
63 NZB = bjc[1];
|
wolffd@0
|
64 pTemp = mxGetField(prhs[0], 0, "sizes");
|
wolffd@0
|
65 pSize = mxGetPr(pTemp);
|
wolffd@0
|
66
|
wolffd@0
|
67 pDomain = mxGetPr(prhs[1]);
|
wolffd@0
|
68 bdim = mxGetNumberOfElements(prhs[1]);
|
wolffd@0
|
69
|
wolffd@0
|
70 mask = malloc(bdim * sizeof(int));
|
wolffd@0
|
71 ssize = malloc(bdim * sizeof(int));
|
wolffd@0
|
72 observed = malloc(bdim * sizeof(int));
|
wolffd@0
|
73
|
wolffd@0
|
74 for(i=0; i<bdim; i++){
|
wolffd@0
|
75 ssize[i] = (int)pSize[i];
|
wolffd@0
|
76 }
|
wolffd@0
|
77
|
wolffd@0
|
78 count = 0;
|
wolffd@0
|
79 for(i=0; i<bdim; i++){
|
wolffd@0
|
80 domain = (int)pDomain[i] - 1;
|
wolffd@0
|
81 pTemp = mxGetCell(prhs[2], domain);
|
wolffd@0
|
82 if(pTemp){
|
wolffd@0
|
83 mask[count] = i;
|
wolffd@0
|
84 ssize[i] = 1;
|
wolffd@0
|
85 observed[count] = (int)mxGetScalar(pTemp) - 1;
|
wolffd@0
|
86 count++;
|
wolffd@0
|
87 }
|
wolffd@0
|
88 }
|
wolffd@0
|
89
|
wolffd@0
|
90 if(count == 0){
|
wolffd@0
|
91 pTemp = mxGetField(prhs[0], 0, "CPT");
|
wolffd@0
|
92 plhs[0] = mxDuplicateArray(pTemp);
|
wolffd@0
|
93 free(mask);
|
wolffd@0
|
94 free(ssize);
|
wolffd@0
|
95 free(observed);
|
wolffd@0
|
96 return;
|
wolffd@0
|
97 }
|
wolffd@0
|
98
|
wolffd@0
|
99 bsubv = malloc(bdim * sizeof(int));
|
wolffd@0
|
100 ssubv = malloc(count * sizeof(int));
|
wolffd@0
|
101 bcumprod = malloc(bdim * sizeof(int));
|
wolffd@0
|
102 scumprod = malloc(bdim * sizeof(int));
|
wolffd@0
|
103
|
wolffd@0
|
104 NS = 1;
|
wolffd@0
|
105 for(i=0; i<bdim; i++){
|
wolffd@0
|
106 NS *= ssize[i];
|
wolffd@0
|
107 }
|
wolffd@0
|
108
|
wolffd@0
|
109 plhs[0] = mxCreateSparse(NS, 1, NS, mxREAL);
|
wolffd@0
|
110 spr = mxGetPr(plhs[0]);
|
wolffd@0
|
111 sir = mxGetIr(plhs[0]);
|
wolffd@0
|
112 sjc = mxGetJc(plhs[0]);
|
wolffd@0
|
113 sjc[0] = 0;
|
wolffd@0
|
114 sjc[1] = NS;
|
wolffd@0
|
115
|
wolffd@0
|
116 bcumprod[0] = 1;
|
wolffd@0
|
117 scumprod[0] = 1;
|
wolffd@0
|
118 for(i=0; i<bdim-1; i++){
|
wolffd@0
|
119 bcumprod[i+1] = bcumprod[i] * (int)pSize[i];
|
wolffd@0
|
120 scumprod[i+1] = scumprod[i] * ssize[i];
|
wolffd@0
|
121 }
|
wolffd@0
|
122
|
wolffd@0
|
123 nzCounts = 0;
|
wolffd@0
|
124 for(i=0; i<NZB; i++){
|
wolffd@0
|
125 bindex = bir[i];
|
wolffd@0
|
126 ind_subv(bindex, bcumprod, bdim, bsubv);
|
wolffd@0
|
127 for(j=0; j<count; j++){
|
wolffd@0
|
128 ssubv[j] = bsubv[mask[j]];
|
wolffd@0
|
129 }
|
wolffd@0
|
130 match = 1;
|
wolffd@0
|
131 for(j=0; j<count; j++){
|
wolffd@0
|
132 if((ssubv[j]) != observed[j]){
|
wolffd@0
|
133 match = 0;
|
wolffd@0
|
134 break;
|
wolffd@0
|
135 }
|
wolffd@0
|
136 }
|
wolffd@0
|
137 if(match){
|
wolffd@0
|
138 spr[nzCounts] = bpr[i];
|
wolffd@0
|
139 sindex = subv_ind(bdim, scumprod, bsubv);
|
wolffd@0
|
140 sir[nzCounts] = sindex;
|
wolffd@0
|
141 nzCounts++;
|
wolffd@0
|
142 }
|
wolffd@0
|
143 }
|
wolffd@0
|
144
|
wolffd@0
|
145 reset_nzmax(plhs[0], NS, nzCounts);
|
wolffd@0
|
146 free(mask);
|
wolffd@0
|
147 free(ssize);
|
wolffd@0
|
148 free(observed);
|
wolffd@0
|
149 free(bsubv);
|
wolffd@0
|
150 free(ssubv);
|
wolffd@0
|
151 free(bcumprod);
|
wolffd@0
|
152 free(scumprod);
|
wolffd@0
|
153 }
|
wolffd@0
|
154
|