comparison toolboxes/FullBNT-1.0.7/bnt/inference/static/@jtree_sparse_inf_engine/collect_evidence.c @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
/* C MEX implementation of collect_evidence for the @jtree_sparse_inf_engine class. */
/* It is called from enter_evidence.m in the @jtree_sparse_inf_engine directory. */
3
/******************************************/
/* collect_evidence: 3 inputs, 2 outputs  */
/* inputs:                                */
/*   engine                               */
/*   clpot                                */
/*   seppot                               */
/* outputs:                               */
/*   clpot                                */
/*   seppot                               */
/******************************************/
13
14 #include <math.h>
15 #include <stdlib.h>
16 #include "mex.h"
17
/* qsort/bsearch comparator on ints. It is also applied to arrays of
 * (key, payload) int pairs, where only the first int (the key) is compared.
 * Returns <0, 0 or >0. The result is formed without subtraction, so it cannot
 * overflow (e.g. INT_MIN vs 1). */
int compare(const void* src1, const void* src2){
    int i1 = *(const int*)src1;
    int i2 = *(const int*)src2;
    return (i1 > i2) - (i1 < i2);
}
23
/* Convert a 0-based linear index into per-dimension subscripts.
 *   index   - linear index to decode
 *   cumprod - stride of each dimension (cumprod[0] == 1)
 *   n       - number of dimensions
 *   bsubv   - out: n subscripts
 * Dimensions are peeled from the highest stride down. The top dimension is not
 * reduced modulo its size, so any overflow of `index` lands there. */
void ind_subv(int index, const int *cumprod, int n, int *bsubv){
    int i;

    for (i = n-1; i >= 0; i--) {
        /* Integer division already truncates; no floating-point floor() needed. */
        bsubv[i] = index / cumprod[i];
        index = index % cumprod[i];
    }
}
32
/* Inverse of ind_subv: combine n per-dimension subscripts into a linear index,
 * weighting subscript k by the stride cumprod[k]. */
int subv_ind(const int n, const int *cumprod, const int *subv){
    int linear = 0;
    int k = 0;

    while (k < n) {
        linear += cumprod[k] * subv[k];
        ++k;
    }
    return linear;
}
41
/* For every joint assignment of the "diff" dimensions (the big-table
 * dimensions listed in dmask), compute that assignment's offset into the big
 * table.
 *   weight   - out: ND offsets
 *   pbSize   - sizes of all big-table dimensions
 *   dmask    - diffdim indices into pbSize/bCumprod selecting the diff dimensions
 *   bCumprod - strides of the big table (bCumprod[0] == 1)
 *   ND       - number of assignments to enumerate
 *   diffdim  - number of diff dimensions
 * Assignment i is decoded mixed-radix with dmask[0] varying fastest. The last
 * diff dimension absorbs any remainder, matching ind_subv. No scratch memory
 * is allocated, and diffdim == 0 yields all-zero offsets. The previous version
 * wrote into a malloc(0) buffer in that case. */
void compute_fixed_weight(int *weight, const double *pbSize, const int *dmask, const int *bCumprod, const int ND, const int diffdim){
    int i, j;

    for(i=0; i<ND; i++){
        int rest = i;
        int offset = 0;
        for(j=0; j<diffdim; j++){
            int sub;
            if(j < diffdim-1){
                int size = (int)pbSize[dmask[j]];
                sub = rest % size;
                rest /= size;
            }
            else{
                sub = rest;
            }
            offset += sub * bCumprod[dmask[j]];
        }
        weight[i] = offset;
    }
}
70
71 void reset_nzmax(mxArray *spArray, const int old_nzmax, const int new_nzmax){
72 double *ptr;
73 void *newptr;
74 int *ir, *jc;
75 int nbytes;
76
77 if(new_nzmax == old_nzmax) return;
78 nbytes = new_nzmax * sizeof(*ptr);
79 ptr = mxGetPr(spArray);
80 newptr = mxRealloc(ptr, nbytes);
81 mxSetPr(spArray, newptr);
82 nbytes = new_nzmax * sizeof(*ir);
83 ir = mxGetIr(spArray);
84 newptr = mxRealloc(ir, nbytes);
85 mxSetIr(spArray, newptr);
86 jc = mxGetJc(spArray);
87 jc[0] = 0;
88 jc[1] = new_nzmax;
89 mxSetNzmax(spArray, new_nzmax);
90 }
91
92 mxArray* convert_ill_table_to_sparse(const double *Table, const int *sequence, const int nzCounts, const int N){
93 mxArray *spTable;
94 int i, temp, *irs, *jcs, count=0;
95 double *sr;
96
97 spTable = mxCreateSparse(N, 1, nzCounts, mxREAL);
98 sr = mxGetPr(spTable);
99 irs = mxGetIr(spTable);
100 jcs = mxGetJc(spTable);
101
102 jcs[0] = 0;
103 jcs[1] = nzCounts;
104
105 for(i=0; i<nzCounts; i++){
106 irs[i] = sequence[count];
107 count++;
108 temp = sequence[count];
109 sr[i] = Table[temp];
110 count++;
111 }
112 return spTable;
113 }
114
115 void multiply_null_by_spPot(mxArray *bigPot, const mxArray *smallPot){
116 int i, j, count, count1, match, temp, bdim, sdim, diffdim, NB, NS, ND, NZB, NZS, bindex, sindex, nzCounts=0;
117 int *samemask, *diffmask, *sir, *sjc, *bCumprod, *sCumprod, *ssubv, *sequence, *weight;
118 double *bigTable, *pbDomain, *psDomain, *pbSize, *psSize, *spr, *bpr;
119 mxArray *pTemp, *pTemp1;
120
121 pTemp = mxGetField(bigPot, 0, "domain");
122 pbDomain = mxGetPr(pTemp);
123 bdim = mxGetNumberOfElements(pTemp);
124 pTemp = mxGetField(smallPot, 0, "domain");
125 psDomain = mxGetPr(pTemp);
126 sdim = mxGetNumberOfElements(pTemp);
127
128 pTemp = mxGetField(bigPot, 0, "sizes");
129 pbSize = mxGetPr(pTemp);
130 pTemp = mxGetField(smallPot, 0, "sizes");
131 psSize = mxGetPr(pTemp);
132
133 pTemp = mxGetField(smallPot, 0, "T");
134 spr = mxGetPr(pTemp);
135 sir = mxGetIr(pTemp);
136 sjc = mxGetJc(pTemp);
137 NZS = sjc[1];
138
139 NB = 1;
140 for(i=0; i<bdim; i++){
141 NB *= (int)pbSize[i];
142 }
143
144 if(sdim == 0){
145 pTemp = mxCreateSparse(NB, 1, NB, mxREAL);
146 mxSetField(bigPot, 0, "T", pTemp);
147 bpr = mxGetPr(pTemp);
148 sir = mxGetIr(pTemp);
149 sjc = mxGetJc(pTemp);
150 sjc[0] = 0;
151 sjc[1] = NB;
152 for(i=0; i<NB; i++){
153 bpr[i] = *spr;
154 sir[i] = i;
155 }
156 return;
157 }
158
159 NS = 1;
160 for(i=0; i<sdim; i++){
161 NS *= (int)psSize[i];
162 }
163 ND = NB / NS;
164
165 if(ND == 1){
166 pTemp1 = mxGetField(smallPot, 0, "T");
167 pTemp = mxDuplicateArray(pTemp1);
168 mxSetField(bigPot, 0, "T", pTemp);
169 return;
170 }
171
172
173 NZB = ND * NZS;
174
175 diffdim = bdim - sdim;
176 sequence = malloc(NZB * 2 * sizeof(int));
177 bigTable = malloc(NZB * sizeof(double));
178 samemask = malloc(sdim * sizeof(int));
179 diffmask = malloc(diffdim * sizeof(int));
180 bCumprod = malloc(bdim * sizeof(int));
181 sCumprod = malloc(sdim * sizeof(int));
182 weight = malloc(ND * sizeof(int));
183 ssubv = malloc(sdim * sizeof(int));
184
185 count = 0;
186 count1 = 0;
187 for(i=0; i<bdim; i++){
188 match = 0;
189 for(j=0; j<sdim; j++){
190 if(pbDomain[i] == psDomain[j]){
191 samemask[count] = i;
192 match = 1;
193 count++;
194 break;
195 }
196 }
197 if(match == 0){
198 diffmask[count1] = i;
199 count1++;
200 }
201 }
202
203 bCumprod[0] = 1;
204 for(i=0; i<bdim-1; i++){
205 bCumprod[i+1] = bCumprod[i] * (int)pbSize[i];
206 }
207 sCumprod[0] = 1;
208 for(i=0; i<sdim-1; i++){
209 sCumprod[i+1] = sCumprod[i] * (int)psSize[i];
210 }
211
212 count = 0;
213 compute_fixed_weight(weight, pbSize, diffmask, bCumprod, ND, diffdim);
214 for(i=0; i<NZS; i++){
215 sindex = sir[i];
216 ind_subv(sindex, sCumprod, sdim, ssubv);
217 temp = 0;
218 for(j=0; j<sdim; j++){
219 temp += ssubv[j] * bCumprod[samemask[j]];
220 }
221 for(j=0; j<ND; j++){
222 bindex = weight[j] + temp;
223 bigTable[nzCounts] = spr[i];
224 sequence[count] = bindex;
225 count++;
226 sequence[count] = nzCounts;
227 nzCounts++;
228 count++;
229 }
230 }
231
232 pTemp = mxGetField(bigPot, 0, "T");
233 if(pTemp)mxDestroyArray(pTemp);
234 qsort(sequence, nzCounts, sizeof(int) * 2, compare);
235 pTemp = convert_ill_table_to_sparse(bigTable, sequence, nzCounts, NB);
236 mxSetField(bigPot, 0, "T", pTemp);
237
238 free(sequence);
239 free(bigTable);
240 free(samemask);
241 free(diffmask);
242 free(bCumprod);
243 free(sCumprod);
244 free(weight);
245 free(ssubv);
246 }
247
248 void multiply_spPot_by_spPot(mxArray *bigPot, const mxArray *smallPot){
249 int i, j, count, bdim, sdim, NB, NZB, NZS, position, bindex, sindex, nzCounts=0;
250 int *mask, *result, *bir, *sir, *rir, *bjc, *sjc, *rjc, *bCumprod, *sCumprod, *bsubv, *ssubv;
251 double *pbDomain, *psDomain, *pbSize, *psSize, *bpr, *spr, *rpr;
252 mxArray *pTemp, *pTemp1;
253
254 pTemp = mxGetField(bigPot, 0, "domain");
255 pbDomain = mxGetPr(pTemp);
256 bdim = mxGetNumberOfElements(pTemp);
257 pTemp = mxGetField(smallPot, 0, "domain");
258 psDomain = mxGetPr(pTemp);
259 sdim = mxGetNumberOfElements(pTemp);
260
261 pTemp = mxGetField(bigPot, 0, "sizes");
262 pbSize = mxGetPr(pTemp);
263 pTemp = mxGetField(smallPot, 0, "sizes");
264 psSize = mxGetPr(pTemp);
265
266 NB = 1;
267 for(i=0; i<bdim; i++){
268 NB *= (int)pbSize[i];
269 }
270
271 pTemp = mxGetField(bigPot, 0, "T");
272 bpr = mxGetPr(pTemp);
273 bir = mxGetIr(pTemp);
274 bjc = mxGetJc(pTemp);
275 NZB = bjc[1];
276
277 pTemp = mxGetField(smallPot, 0, "T");
278 spr = mxGetPr(pTemp);
279 sir = mxGetIr(pTemp);
280 sjc = mxGetJc(pTemp);
281 NZS = sjc[1];
282
283 if(sdim == 0){
284 for(i=0; i<NZB; i++){
285 bpr[i] *= *spr;
286 }
287 return;
288 }
289
290 pTemp1 = mxCreateSparse(NB, 1, NZB, mxREAL);
291 rpr = mxGetPr(pTemp1);
292 rir = mxGetIr(pTemp1);
293 rjc = mxGetJc(pTemp1);
294 rjc[0] = 0;
295 rjc[1] = NZB;
296
297 mask = malloc(sdim * sizeof(int));
298 bCumprod = malloc(bdim * sizeof(int));
299 sCumprod = malloc(sdim * sizeof(int));
300 bsubv = malloc(bdim * sizeof(int));
301 ssubv = malloc(sdim * sizeof(int));
302
303 count = 0;
304 for(i=0; i<sdim; i++){
305 for(j=0; j<bdim; j++){
306 if(psDomain[i] == pbDomain[j]){
307 mask[count] = j;
308 count++;
309 break;
310 }
311 }
312 }
313
314 bCumprod[0] = 1;
315 for(i=0; i<bdim-1; i++){
316 bCumprod[i+1] = bCumprod[i] * (int)pbSize[i];
317 }
318 sCumprod[0] = 1;
319 for(i=0; i<sdim-1; i++){
320 sCumprod[i+1] = sCumprod[i] * (int)psSize[i];
321 }
322
323 for(i=0; i<NZB; i++){
324 bindex = bir[i];
325 ind_subv(bindex, bCumprod, bdim, bsubv);
326 for(j=0; j<sdim; j++){
327 ssubv[j] = bsubv[mask[j]];
328 }
329 sindex = subv_ind(sdim, sCumprod, ssubv);
330 result = (int *) bsearch(&sindex, sir, NZS, sizeof(int), compare);
331 if(result){
332 position = result - sir;
333 rpr[nzCounts] = bpr[i] * spr[position];
334 rir[nzCounts] = bindex;
335 nzCounts++;
336 }
337 }
338
339 pTemp = mxGetField(bigPot, 0, "T");
340 if(pTemp)mxDestroyArray(pTemp);
341 reset_nzmax(pTemp1, NZB, nzCounts);
342 mxSetField(bigPot, 0, "T", pTemp1);
343
344 free(mask);
345 free(bCumprod);
346 free(sCumprod);
347 free(bsubv);
348 free(ssubv);
349 }
350
351 mxArray* marginal_null_to_spPot(const mxArray *bigPot, const mxArray *sDomain, const int maximize){
352 int i, j, count, bdim, sdim, NB, NS, ND;
353 int *mask, *sir, *sjc;
354 double *pbDomain, *psDomain, *pbSize, *psSize, *spr;
355 mxArray *pTemp, *smallPot;
356 const char *field_names[] = {"domain", "T", "sizes"};
357
358 pTemp = mxGetField(bigPot, 0, "domain");
359 pbDomain = mxGetPr(pTemp);
360 bdim = mxGetNumberOfElements(pTemp);
361 psDomain = mxGetPr(sDomain);
362 sdim = mxGetNumberOfElements(sDomain);
363 pTemp = mxGetField(bigPot, 0, "sizes");
364 pbSize = mxGetPr(pTemp);
365
366 smallPot = mxCreateStructMatrix(1, 1, 3, field_names);
367 pTemp = mxDuplicateArray(sDomain);
368 mxSetField(smallPot, 0, "domain", pTemp);
369
370 NB = 1;
371 for(i=0; i<bdim; i++){
372 NB *= (int)pbSize[i];
373 }
374
375 if(sdim == 0){
376 pTemp = mxCreateSparse(1, 1, 1, mxREAL);
377 mxSetField(smallPot, 0, "T", pTemp);
378 spr = mxGetPr(pTemp);
379 sir = mxGetIr(pTemp);
380 sjc = mxGetJc(pTemp);
381 *spr = 0;
382 *sir = 0;
383 sjc[0] = 0;
384 sjc[1] = 1;
385 if(maximize) *spr = 1;
386 else *spr = NB;
387
388 pTemp = mxCreateDoubleMatrix(1, 1, mxREAL);
389 *mxGetPr(pTemp) = 1;
390 mxSetField(smallPot, 0, "sizes", pTemp);
391 return smallPot;
392 }
393
394 mask = malloc(sdim * sizeof(int));
395 count = 0;
396 for(i=0; i<sdim; i++){
397 for(j=0; j<bdim; j++){
398 if(psDomain[i] == pbDomain[j]){
399 mask[count] = j;
400 count++;
401 break;
402 }
403 }
404 }
405 pTemp = mxCreateDoubleMatrix(1, count, mxREAL);
406 psSize = mxGetPr(pTemp);
407 NS = 1;
408 for(i=0; i<count; i++){
409 psSize[i] = pbSize[mask[i]];
410 NS *= (int)psSize[i];
411 }
412 mxSetField(smallPot, 0, "sizes", pTemp);
413
414 ND = NB / NS;
415
416 pTemp = mxCreateSparse(NS, 1, NS, mxREAL);
417 mxSetField(smallPot, 0, "T", pTemp);
418 spr = mxGetPr(pTemp);
419 sir = mxGetIr(pTemp);
420 sjc = mxGetJc(pTemp);
421 if(maximize){
422 for(i=0; i<NS; i++){
423 spr[i] = 1;
424 sir[i] = i;
425 }
426 }
427 else{
428 for(i=0; i<NS; i++){
429 spr[i] = ND;
430 sir[i] = i;
431 }
432 }
433 sjc[0] = 0;
434 sjc[1] = NS;
435
436 free(mask);
437 return smallPot;
438 }
439
440 mxArray* marginal_spPot_to_spPot(const mxArray *bigPot, const mxArray *sDomain, const int maximize){
441 int i, j, count, bdim, sdim, NB, NS, NZB, position, bindex, sindex, nzCounts=0;
442 int *mask, *sequence, *result, *bir, *bjc, *bCumprod, *sCumprod, *bsubv, *ssubv;
443 double *sTable, *pbDomain, *psDomain, *pbSize, *psSize, *bpr, *spr;
444 mxArray *pTemp, *smallPot;
445 const char *field_names[] = {"domain", "T", "sizes"};
446
447 pTemp = mxGetField(bigPot, 0, "domain");
448 pbDomain = mxGetPr(pTemp);
449 bdim = mxGetNumberOfElements(pTemp);
450 psDomain = mxGetPr(sDomain);
451 sdim = mxGetNumberOfElements(sDomain);
452 pTemp = mxGetField(bigPot, 0, "sizes");
453 pbSize = mxGetPr(pTemp);
454
455 pTemp = mxGetField(bigPot, 0, "T");
456 bpr = mxGetPr(pTemp);
457 bir = mxGetIr(pTemp);
458 bjc = mxGetJc(pTemp);
459 NZB = bjc[1];
460
461 smallPot = mxCreateStructMatrix(1, 1, 3, field_names);
462 pTemp = mxDuplicateArray(sDomain);
463 mxSetField(smallPot, 0, "domain", pTemp);
464
465 if(sdim == 0){
466 pTemp = mxCreateSparse(1, 1, 1, mxREAL);
467 mxSetField(smallPot, 0, "T", pTemp);
468 spr = mxGetPr(pTemp);
469 bir = mxGetIr(pTemp);
470 bjc = mxGetJc(pTemp);
471 *spr = 0;
472 *bir = 0;
473 bjc[0] = 0;
474 bjc[1] = 1;
475 if(maximize){
476 for(i=0; i<NZB; i++){
477 *spr = (*spr < bpr[i])? bpr[i] : *spr;
478 }
479 }
480 else{
481 for(i=0; i<NZB; i++){
482 *spr += bpr[i];
483 }
484 }
485
486 pTemp = mxCreateDoubleMatrix(1, 1, mxREAL);
487 *mxGetPr(pTemp) = 1;
488 mxSetField(smallPot, 0, "sizes", pTemp);
489 return smallPot;
490 }
491
492 NB = 1;
493 for(i=0; i<bdim; i++){
494 NB *= (int)pbSize[i];
495 }
496
497 mask = malloc(sdim * sizeof(int));
498 count = 0;
499 for(i=0; i<sdim; i++){
500 for(j=0; j<bdim; j++){
501 if(psDomain[i] == pbDomain[j]){
502 mask[count] = j;
503 count++;
504 break;
505 }
506 }
507 }
508 pTemp = mxCreateDoubleMatrix(1, count, mxREAL);
509 psSize = mxGetPr(pTemp);
510 NS = 1;
511 for(i=0; i<count; i++){
512 psSize[i] = pbSize[mask[i]];
513 NS *= (int)psSize[i];
514 }
515 mxSetField(smallPot, 0, "sizes", pTemp);
516
517
518 sTable = malloc(NZB * sizeof(double));
519 sequence = malloc(NZB * 2 * sizeof(double));
520 bCumprod = malloc(bdim * sizeof(int));
521 sCumprod = malloc(sdim * sizeof(int));
522 bsubv = malloc(bdim * sizeof(int));
523 ssubv = malloc(sdim * sizeof(int));
524
525 for(i=0; i<NZB; i++)sTable[i] = 0;
526
527 bCumprod[0] = 1;
528 for(i=0; i<bdim-1; i++){
529 bCumprod[i+1] = bCumprod[i] * (int)pbSize[i];
530 }
531 sCumprod[0] = 1;
532 for(i=0; i<sdim-1; i++){
533 sCumprod[i+1] = sCumprod[i] * (int)psSize[i];
534 }
535
536 count = 0;
537 for(i=0; i<NZB; i++){
538 bindex = bir[i];
539 ind_subv(bindex, bCumprod, bdim, bsubv);
540 for(j=0; j<sdim; j++){
541 ssubv[j] = bsubv[mask[j]];
542 }
543 sindex = subv_ind(sdim, sCumprod, ssubv);
544 result = (int *) bsearch(&sindex, sequence, nzCounts, sizeof(int)*2, compare);
545 if(result){
546 position = (result - sequence) / 2;
547 if(maximize)
548 sTable[position] = (sTable[position] < bpr[i]) ? bpr[i] : sTable[position];
549 else sTable[position] += bpr[i];
550 }
551 else {
552 if(maximize)
553 sTable[nzCounts] = (sTable[nzCounts] < bpr[i]) ? bpr[i] : sTable[nzCounts];
554 else sTable[nzCounts] += bpr[i];
555 sequence[count] = sindex;
556 count++;
557 sequence[count] = nzCounts;
558 nzCounts++;
559 count++;
560 }
561 }
562
563 qsort(sequence, nzCounts, sizeof(int) * 2, compare);
564 pTemp = convert_ill_table_to_sparse(sTable, sequence, nzCounts, NS);
565 mxSetField(smallPot, 0, "T", pTemp);
566
567 free(sTable);
568 free(sequence);
569 free(mask);
570 free(bCumprod);
571 free(sCumprod);
572 free(bsubv);
573 free(ssubv);
574
575 return smallPot;
576 }
577
578
579 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){
580 int i, n, p, np, pn, loop, loops, nCliques, temp, maximize;
581 int *collect_order;
582 double *pr, *pr1;
583 mxArray *pTemp, *pTemp1, *pPostP, *pClpot, *pSeppot, *pSeparator;
584
585 pTemp = mxGetField(prhs[0], 0, "cliques");
586 nCliques = mxGetNumberOfElements(pTemp);
587 loops = nCliques - 1;
588 pTemp = mxGetField(prhs[0], 0, "maximize");
589 maximize = (int)mxGetScalar(pTemp);
590 pSeparator = mxGetField(prhs[0], 0, "separator");
591
592 collect_order = malloc(2 * loops * sizeof(int));
593
594 pTemp = mxGetField(prhs[0], 0, "postorder");
595 pr = mxGetPr(pTemp);
596 pPostP = mxGetField(prhs[0], 0, "postorder_parents");
597 for(i=0; i<loops; i++){
598 temp = (int)pr[i] - 1;
599 pTemp = mxGetCell(pPostP, temp);
600 pr1 = mxGetPr(pTemp);
601 collect_order[i] = (int)pr1[0] - 1;
602 collect_order[i+loops] = temp;
603 }
604
605 plhs[0] = mxDuplicateArray(prhs[1]);
606 plhs[1] = mxDuplicateArray(prhs[2]);
607
608 for(loop=0; loop<loops; loop++){
609 p = collect_order[loop];
610 n = collect_order[loop+loops];
611 np = p * nCliques + n;
612 pn = n * nCliques + p;
613 pClpot = mxGetCell(plhs[0], n);
614 pTemp1 = mxGetField(pClpot, 0, "T");
615 pTemp = mxGetCell(pSeparator, pn);
616 if(pTemp1)
617 pSeppot = marginal_spPot_to_spPot(pClpot, pTemp, maximize);
618 else pSeppot = marginal_null_to_spPot(pClpot, pTemp, maximize);
619 mxSetCell(plhs[1], pn, pSeppot);
620
621 pClpot = mxGetCell(plhs[0], p);
622 pTemp1 = mxGetField(pClpot, 0, "T");
623 if(pTemp1)
624 multiply_spPot_by_spPot(pClpot, pSeppot);
625 else multiply_null_by_spPot(pClpot, pSeppot);
626 }
627 free(collect_order);
628 }
629
630
631
632
633
634