To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.

Statistics Download as Zip
| Branch: | Revision:

root / _FullBNT / BNT / CPDs / @discrete_CPD / convert_to_sparse_table.c @ 8:b5b38998ef3b

History | View | Annotate | Download (3.52 KB)

1
/* convert_to_sparse_table.c  convert a sparse discrete CPD with evidence into sparse table */
2
/* convert_to_pot.m located in ../CPDs/discrete_CPD call it */
3
/* 3 input */
4
/* CPD      prhs[0] with 1D sparse CPT */
5
/* domain   prhs[1]                    */
6
/* evidence prhs[2]                    */
7
/* 1 output */
8
/* T        plhs[0] sparse table       */
9

    
10
#include <math.h>
11
#include "mex.h"
12

    
13
void ind_subv(int index, const int *cumprod, const int n, int *bsubv){
14
        int i;
15

    
16
        for (i = n-1; i >= 0; i--) {
17
                bsubv[i] = ((int)floor(index / cumprod[i]));
18
                index = index % cumprod[i];
19
        }
20
}
21

    
22
int subv_ind(const int n, const int *cumprod, const int *subv){
23
        int i, index=0;
24

    
25
        for(i=0; i<n; i++){
26
                index += subv[i] * cumprod[i];
27
        }
28
        return index;
29
}
30

    
31
void reset_nzmax(mxArray *spArray, const int old_nzmax, const int new_nzmax){
32
        double *ptr;
33
        void   *newptr;
34
        int    *ir, *jc;
35
        int    nbytes;
36

    
37
        if(new_nzmax == old_nzmax) return;
38
        nbytes = new_nzmax * sizeof(*ptr);
39
        ptr = mxGetPr(spArray);
40
        newptr = mxRealloc(ptr, nbytes);
41
        mxSetPr(spArray, newptr);
42
        nbytes = new_nzmax * sizeof(*ir);
43
        ir = mxGetIr(spArray);
44
        newptr = mxRealloc(ir, nbytes);
45
        mxSetIr(spArray, newptr);
46
        jc = mxGetJc(spArray);
47
        jc[0] = 0;
48
        jc[1] = new_nzmax;
49
        mxSetNzmax(spArray, new_nzmax);
50
}
51

    
52

    
53
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){
54
        int     i, j, NS, NZB, count, bdim, match, domain, bindex, sindex, nzCounts=0;
55
        int     *observed, *bsubv, *ssubv, *bir, *sir, *bjc, *sjc, *mask, *ssize, *bcumprod, *scumprod;
56
        double  *pDomain, *pSize, *bpr, *spr;
57
        mxArray *pTemp;
58

    
59
        pTemp = mxGetField(prhs[0], 0, "CPT");
60
        bpr = mxGetPr(pTemp);
61
        bir = mxGetIr(pTemp);
62
        bjc = mxGetJc(pTemp);
63
        NZB = bjc[1];
64
        pTemp = mxGetField(prhs[0], 0, "sizes");
65
        pSize = mxGetPr(pTemp);
66

    
67
        pDomain = mxGetPr(prhs[1]);
68
        bdim = mxGetNumberOfElements(prhs[1]);
69

    
70
        mask = malloc(bdim * sizeof(int));
71
        ssize = malloc(bdim * sizeof(int));
72
        observed = malloc(bdim * sizeof(int));
73

    
74
        for(i=0; i<bdim; i++){
75
                ssize[i] = (int)pSize[i];
76
        }
77

    
78
        count = 0;
79
        for(i=0; i<bdim; i++){
80
                domain = (int)pDomain[i] - 1;
81
                pTemp = mxGetCell(prhs[2], domain);
82
                if(pTemp){
83
                        mask[count] = i;
84
                        ssize[i] = 1;
85
                        observed[count] = (int)mxGetScalar(pTemp) - 1;
86
                        count++;
87
                }
88
        }
89

    
90
        if(count == 0){
91
                pTemp = mxGetField(prhs[0], 0, "CPT");
92
                plhs[0] = mxDuplicateArray(pTemp);
93
                free(mask);
94
                free(ssize);
95
                free(observed);
96
                return;
97
        }
98

    
99
        bsubv = malloc(bdim * sizeof(int));
100
        ssubv = malloc(count * sizeof(int));
101
        bcumprod = malloc(bdim * sizeof(int));
102
        scumprod = malloc(bdim * sizeof(int));
103

    
104
        NS = 1;
105
        for(i=0; i<bdim; i++){
106
                NS *= ssize[i];
107
        }
108

    
109
        plhs[0] = mxCreateSparse(NS, 1, NS, mxREAL);
110
        spr = mxGetPr(plhs[0]);
111
        sir = mxGetIr(plhs[0]);
112
        sjc = mxGetJc(plhs[0]);
113
        sjc[0] = 0;
114
        sjc[1] = NS;
115

    
116
        bcumprod[0] = 1;
117
        scumprod[0] = 1;
118
        for(i=0; i<bdim-1; i++){
119
                bcumprod[i+1] = bcumprod[i] * (int)pSize[i];
120
                scumprod[i+1] = scumprod[i] * ssize[i];
121
        }
122

    
123
        nzCounts = 0;
124
        for(i=0; i<NZB; i++){
125
                bindex = bir[i];
126
                ind_subv(bindex, bcumprod, bdim, bsubv);
127
                for(j=0; j<count; j++){
128
                        ssubv[j] = bsubv[mask[j]];
129
                }
130
                match = 1;
131
                for(j=0; j<count; j++){
132
                        if((ssubv[j]) != observed[j]){
133
                                match = 0;
134
                                break;
135
                        }
136
                }
137
                if(match){
138
                        spr[nzCounts] = bpr[i];
139
                        sindex = subv_ind(bdim, scumprod, bsubv);
140
                        sir[nzCounts] = sindex;
141
                        nzCounts++;
142
                }
143
        }
144

    
145
        reset_nzmax(plhs[0], NS, nzCounts);
146
        free(mask);
147
        free(ssize);
148
        free(observed);
149
        free(bsubv);
150
        free(ssubv);
151
        free(bcumprod);
152
        free(scumprod);
153
}
154