Mercurial > hg > camir-aes2014
comparison toolboxes/FullBNT-1.0.7/bnt/CPDs/@discrete_CPD/convert_to_sparse_table.c @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
/* convert_to_sparse_table.c: convert a discrete CPD whose CPT is sparse, given evidence, into a sparse table. */
/* Called by convert_to_pot.m (in ../CPDs/@discrete_CPD). */
/* 3 inputs: */
/*   prhs[0]  CPD, whose CPT is stored as a 1-D sparse array */
/*   prhs[1]  domain */
/*   prhs[2]  evidence */
/* 1 output: */
/*   plhs[0]  T, the resulting sparse table */
9 | |
10 #include <math.h> | |
11 #include "mex.h" | |
12 | |
/*
 * ind_subv: split a 0-based linear index into per-dimension subscripts.
 *
 * index    0-based linear index.
 * cumprod  stride of each dimension (cumprod[d] is the product of the sizes
 *          of dimensions 0..d-1); n entries.
 * n        number of dimensions.
 * bsubv    output; receives the n 0-based subscripts.
 *
 * Dimensions are peeled from the last (largest stride) to the first.
 */
void ind_subv(int index, const int *cumprod, const int n, int *bsubv){
    int remaining = index;
    int d = n;

    while (d-- > 0) {
        const int stride = cumprod[d];
        bsubv[d] = remaining / stride;
        remaining -= bsubv[d] * stride;   /* same as remaining % stride */
    }
}
21 | |
/*
 * subv_ind: combine per-dimension subscripts into a 0-based linear index.
 *
 * n        number of dimensions.
 * cumprod  stride of each dimension; n entries.
 * subv     0-based subscripts; n entries.
 *
 * Returns sum over d of subv[d] * cumprod[d] (0 when n <= 0).
 */
int subv_ind(const int n, const int *cumprod, const int *subv){
    int linear = 0;
    int d;

    for (d = n - 1; d >= 0; d--) {
        linear += cumprod[d] * subv[d];
    }
    return linear;
}
30 | |
31 void reset_nzmax(mxArray *spArray, const int old_nzmax, const int new_nzmax){ | |
32 double *ptr; | |
33 void *newptr; | |
34 int *ir, *jc; | |
35 int nbytes; | |
36 | |
37 if(new_nzmax == old_nzmax) return; | |
38 nbytes = new_nzmax * sizeof(*ptr); | |
39 ptr = mxGetPr(spArray); | |
40 newptr = mxRealloc(ptr, nbytes); | |
41 mxSetPr(spArray, newptr); | |
42 nbytes = new_nzmax * sizeof(*ir); | |
43 ir = mxGetIr(spArray); | |
44 newptr = mxRealloc(ir, nbytes); | |
45 mxSetIr(spArray, newptr); | |
46 jc = mxGetJc(spArray); | |
47 jc[0] = 0; | |
48 jc[1] = new_nzmax; | |
49 mxSetNzmax(spArray, new_nzmax); | |
50 } | |
51 | |
52 | |
53 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){ | |
54 int i, j, NS, NZB, count, bdim, match, domain, bindex, sindex, nzCounts=0; | |
55 int *observed, *bsubv, *ssubv, *bir, *sir, *bjc, *sjc, *mask, *ssize, *bcumprod, *scumprod; | |
56 double *pDomain, *pSize, *bpr, *spr; | |
57 mxArray *pTemp; | |
58 | |
59 pTemp = mxGetField(prhs[0], 0, "CPT"); | |
60 bpr = mxGetPr(pTemp); | |
61 bir = mxGetIr(pTemp); | |
62 bjc = mxGetJc(pTemp); | |
63 NZB = bjc[1]; | |
64 pTemp = mxGetField(prhs[0], 0, "sizes"); | |
65 pSize = mxGetPr(pTemp); | |
66 | |
67 pDomain = mxGetPr(prhs[1]); | |
68 bdim = mxGetNumberOfElements(prhs[1]); | |
69 | |
70 mask = malloc(bdim * sizeof(int)); | |
71 ssize = malloc(bdim * sizeof(int)); | |
72 observed = malloc(bdim * sizeof(int)); | |
73 | |
74 for(i=0; i<bdim; i++){ | |
75 ssize[i] = (int)pSize[i]; | |
76 } | |
77 | |
78 count = 0; | |
79 for(i=0; i<bdim; i++){ | |
80 domain = (int)pDomain[i] - 1; | |
81 pTemp = mxGetCell(prhs[2], domain); | |
82 if(pTemp){ | |
83 mask[count] = i; | |
84 ssize[i] = 1; | |
85 observed[count] = (int)mxGetScalar(pTemp) - 1; | |
86 count++; | |
87 } | |
88 } | |
89 | |
90 if(count == 0){ | |
91 pTemp = mxGetField(prhs[0], 0, "CPT"); | |
92 plhs[0] = mxDuplicateArray(pTemp); | |
93 free(mask); | |
94 free(ssize); | |
95 free(observed); | |
96 return; | |
97 } | |
98 | |
99 bsubv = malloc(bdim * sizeof(int)); | |
100 ssubv = malloc(count * sizeof(int)); | |
101 bcumprod = malloc(bdim * sizeof(int)); | |
102 scumprod = malloc(bdim * sizeof(int)); | |
103 | |
104 NS = 1; | |
105 for(i=0; i<bdim; i++){ | |
106 NS *= ssize[i]; | |
107 } | |
108 | |
109 plhs[0] = mxCreateSparse(NS, 1, NS, mxREAL); | |
110 spr = mxGetPr(plhs[0]); | |
111 sir = mxGetIr(plhs[0]); | |
112 sjc = mxGetJc(plhs[0]); | |
113 sjc[0] = 0; | |
114 sjc[1] = NS; | |
115 | |
116 bcumprod[0] = 1; | |
117 scumprod[0] = 1; | |
118 for(i=0; i<bdim-1; i++){ | |
119 bcumprod[i+1] = bcumprod[i] * (int)pSize[i]; | |
120 scumprod[i+1] = scumprod[i] * ssize[i]; | |
121 } | |
122 | |
123 nzCounts = 0; | |
124 for(i=0; i<NZB; i++){ | |
125 bindex = bir[i]; | |
126 ind_subv(bindex, bcumprod, bdim, bsubv); | |
127 for(j=0; j<count; j++){ | |
128 ssubv[j] = bsubv[mask[j]]; | |
129 } | |
130 match = 1; | |
131 for(j=0; j<count; j++){ | |
132 if((ssubv[j]) != observed[j]){ | |
133 match = 0; | |
134 break; | |
135 } | |
136 } | |
137 if(match){ | |
138 spr[nzCounts] = bpr[i]; | |
139 sindex = subv_ind(bdim, scumprod, bsubv); | |
140 sir[nzCounts] = sindex; | |
141 nzCounts++; | |
142 } | |
143 } | |
144 | |
145 reset_nzmax(plhs[0], NS, nzCounts); | |
146 free(mask); | |
147 free(ssize); | |
148 free(observed); | |
149 free(bsubv); | |
150 free(ssubv); | |
151 free(bcumprod); | |
152 free(scumprod); | |
153 } | |
154 |