comparison toolboxes/FullBNT-1.0.7/bnt/inference/static/@jtree_sparse_inf_engine/old/init_pot.c @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
/* C MEX implementation of init_pot for the @jtree_sparse_inf_engine directory. */
/* It is called by enter_evidence.m in the @jtree_sparse_inf_engine directory. */
3
4 /**************************************/
/* init_pot.c takes 6 inputs and returns 2 outputs */
6 /* engine */
7 /* clqs */
8 /* pots */
9 /* pot_type */
10 /* onodes */
11 /* ndx */
12 /* */
13 /* clpot */
14 /* seppot */
15 /**************************************/
16 #include <math.h>
17 #include <search.h>
18 #include "mex.h"
19
/*
 * qsort/bsearch comparator for int keys.
 *
 * Returns a negative value, zero, or a positive value when the int at src1
 * is less than, equal to, or greater than the int at src2.  Explicit
 * comparisons are used instead of `i1 - i2`, which overflows (undefined
 * behaviour) when the operands have large magnitudes and opposite signs.
 */
int compare(const void* src1, const void* src2){
    const int i1 = *(const int *)src1;
    const int i2 = *(const int *)src2;

    return (i1 > i2) - (i1 < i2);
}
25
/*
 * Convert a 0-based linear index into per-dimension subscripts.
 *
 * cumprod[d] is the stride of dimension d.  Subscripts are peeled off from
 * the highest dimension (n-1) down to dimension 0 and written to
 * bsubv[0..n-1]; the top subscript receives the full quotient.
 */
void ind_subv(int index, const int *cumprod, int n, int *bsubv){
    int remaining = index;
    int d = n;

    while (d-- > 0) {
        const int stride = cumprod[d];
        bsubv[d] = remaining / stride;
        remaining %= stride;
    }
}
34
/*
 * Convert per-dimension subscripts back into a 0-based linear index: the
 * dot product of subv[0..n-1] with the strides cumprod[0..n-1].  This is
 * the inverse of ind_subv for in-range subscripts.
 */
int subv_ind(const int n, const int *cumprod, const int *subv){
    int linear = 0;
    int k;

    for (k = n - 1; k >= 0; k--) {
        linear += cumprod[k] * subv[k];
    }
    return linear;
}
43
/*
 * For every configuration of the "diff" dimensions, compute its offset in
 * the big table.
 *
 * weight    out: ND offsets.
 * pbSize    extents of the big potential's dimensions.
 * dmask     big-domain positions of the diffdim diff dimensions.
 * bCumprod  column-major strides of the big table.
 *
 * Configuration i (0 <= i < ND) is decoded as a mixed-radix number whose
 * digit j ranges over pbSize[dmask[j]] (least significant digit first), and
 *     weight[i] = sum_j digit_j * bCumprod[dmask[j]].
 * The most significant digit takes the whole remaining quotient, exactly as
 * ind_subv() would.  The decoding uses div/mod in place, so no scratch
 * buffers are allocated.  With diffdim == 0 every weight is 0.
 */
void compute_fixed_weight(int *weight, const double *pbSize, const int *dmask, const int *bCumprod, const int ND, const int diffdim){
    int i, j, rest, w, radix;

    for(i=0; i<ND; i++){
        rest = i;
        w = 0;
        /* All but the last diff dimension: peel off one bounded digit. */
        for(j=0; j<diffdim-1; j++){
            radix = (int)pbSize[dmask[j]];
            w += bCumprod[dmask[j]] * (rest % radix);
            rest /= radix;
        }
        /* The last (most significant) digit takes whatever is left. */
        if(diffdim > 0){
            w += bCumprod[dmask[diffdim-1]] * rest;
        }
        weight[i] = w;
    }
}
72
73 mxArray* convert_to_sparse(const double *table, const int NB, const int counts){
74 mxArray *spTable;
75 int i, k, *ir, *jc;
76 double *sr;
77
78 spTable = mxCreateSparse(NB, 1, counts, mxREAL);
79 sr = mxGetPr(spTable);
80 ir = mxGetIr(spTable);
81 jc = mxGetJc(spTable);
82
83 k = 0;
84 jc[0] = 0;
85 jc[1] = counts;
86 for(i=0; i<NB; i++){
87 if(table[i] != 0.0){
88 sr[k] = table[i];
89 ir[k] = i;
90 k++;
91 }
92 }
93
94 return spTable;
95 }
96
97 mxArray* convert_table_to_sparse(const double *bT, const int *index, const int nzCounts, const int NB){
98 mxArray *spTable;
99 int i, *irs, *jcs;
100 double *sr;
101
102 spTable = mxCreateSparse(NB, 1, nzCounts, mxREAL);
103 sr = mxGetPr(spTable);
104 irs = mxGetIr(spTable);
105 jcs = mxGetJc(spTable);
106
107 jcs[0] = 0;
108 jcs[1] = nzCounts;
109
110 for(i=0; i<nzCounts; i++){
111 sr[i] = bT[i];
112 irs[i] = index[i];
113 }
114 return spTable;
115 }
116
117 mxArray* convert_ill_table_to_sparse(const double *bigTable, const int *sequence, const int nzCounts, const int NB){
118 mxArray *spTable;
119 int i, temp, *irs, *jcs, count=0;
120 double *sr;
121
122 spTable = mxCreateSparse(NB, 1, nzCounts, mxREAL);
123 sr = mxGetPr(spTable);
124 irs = mxGetIr(spTable);
125 jcs = mxGetJc(spTable);
126
127 jcs[0] = 0;
128 jcs[1] = nzCounts;
129
130 for(i=0; i<nzCounts; i++){
131 irs[i] = sequence[count];
132 count++;
133 temp = sequence[count];
134 sr[i] = bigTable[temp];
135 count++;
136 }
137 return spTable;
138 }
139
140 void multiply_null_by_fuPot(mxArray *bigPot, const mxArray *smallPot){
141 int i, j, count, NB, NS, siz_b, siz_s, ndim, nzCounts=0;
142 int *mask, *sx, *sy, *cpsy, *subs, *s, *cpsy2, *jc;
143 double *pbDomain, *psDomain, *pbSize, *psSize, *bTable, *sTable, value;
144 mxArray *pTemp, *pTemp1;
145
146 pTemp = mxGetField(bigPot, 0, "domain");
147 pbDomain = mxGetPr(pTemp);
148 siz_b = mxGetNumberOfElements(pTemp);
149 pTemp = mxGetField(smallPot, 0, "domain");
150 psDomain = mxGetPr(pTemp);
151 siz_s = mxGetNumberOfElements(pTemp);
152
153 pTemp = mxGetField(bigPot, 0, "sizes");
154 pbSize = mxGetPr(pTemp);
155 pTemp = mxGetField(smallPot, 0, "sizes");
156 psSize = mxGetPr(pTemp);
157
158 NB = 1;
159 for(i=0; i<siz_b; i++){
160 NB *= (int)pbSize[i];
161 }
162 NS = 1;
163 for(i=0; i<siz_s; i++){
164 NS *= (int)psSize[i];
165 }
166
167 pTemp = mxGetField(smallPot, 0, "T");
168 sTable = mxGetPr(pTemp);
169 bTable = malloc(NB * sizeof(double));
170 for(i=0; i<NB; i++){
171 bTable[i] = 0;
172 }
173
174 if(NS == 1){
175 value = *sTable;
176 for(i=0; i<NB; i++){
177 bTable[i] = value;
178 }
179 nzCounts = NB;
180 pTemp = mxGetField(bigPot, 0, "T");
181 if(pTemp)mxDestroyArray(pTemp);
182 pTemp = convert_to_sparse(bTable, NB, NB);
183 mxSetField(bigPot, 0, "T", pTemp);
184 free(bTable);
185 return;
186 }
187
188 if(NS == NB){
189 for(i=0; i<NB; i++){
190 bTable[i] = sTable[i];
191 if(sTable[i] != 0) nzCounts++;
192 }
193 pTemp = mxGetField(bigPot, 0, "T");
194 if(pTemp)mxDestroyArray(pTemp);
195 pTemp = convert_to_sparse(bTable, NB, nzCounts);
196 mxSetField(bigPot, 0, "T", pTemp);
197 free(bTable);
198 return;
199 }
200
201 mask = malloc(siz_s * sizeof(int));
202 count = 0;
203 for(i=0; i<siz_s; i++){
204 for(j=0; j<siz_b; j++){
205 if(psDomain[i] == pbDomain[j]){
206 mask[count] = j;
207 count++;
208 break;
209 }
210 }
211 }
212
213 ndim = siz_b;
214 sx = (int *)malloc(sizeof(int)*ndim);
215 sy = (int *)malloc(sizeof(int)*ndim);
216 for(i=0; i<ndim; i++){
217 sx[i] = (int)pbSize[i];
218 sy[i] = 1;
219 }
220 for(i=0; i<count; i++){
221 sy[mask[i]] = sx[mask[i]];
222 }
223
224 s = (int *)malloc(sizeof(int)*ndim);
225 *(cpsy = (int *)malloc(sizeof(int)*ndim)) = 1;
226 subs = (int *)malloc(sizeof(int)*ndim);
227 cpsy2 = (int *)malloc(sizeof(int)*ndim);
228 for(i = 0; i < ndim; i++){
229 subs[i] = 0;
230 s[i] = sx[i] - 1;
231 }
232
233 for(i = 0; i < ndim-1; i++){
234 cpsy[i+1] = cpsy[i]*sy[i]--;
235 cpsy2[i] = cpsy[i]*sy[i];
236 }
237 cpsy2[ndim-1] = cpsy[ndim-1]*(--sy[ndim-1]);
238
239 for(j=0; j<NB; j++){
240 bTable[j] = *sTable;
241 if(*sTable != 0.0) nzCounts++;
242 for(i = 0; i < ndim; i++){
243 if(subs[i] == s[i]){
244 subs[i] = 0;
245 if(sy[i])
246 sTable -= cpsy2[i];
247 }
248 else{
249 subs[i]++;
250 if(sy[i])
251 sTable += cpsy[i];
252 break;
253 }
254 }
255 }
256
257 pTemp = mxGetField(bigPot, 0, "T");
258 if(pTemp)mxDestroyArray(pTemp);
259 pTemp = convert_to_sparse(bTable, NB, nzCounts);
260 mxSetField(bigPot, 0, "T", pTemp);
261 pTemp1 = mxGetField(bigPot, 0, "T");
262 jc = mxGetJc(pTemp1);
263
264 free(sx);
265 free(sy);
266 free(s);
267 free(cpsy);
268 free(subs);
269 free(cpsy2);
270 free(mask);
271 free(bTable);
272 }
273
274 void multiply_null_by_spPot(mxArray *bigPot, const mxArray *smallPot){
275 int i, j, count, count1, match, temp, bdim, sdim, diffdim, NB, NS, ND, NZB, NZS, bindex, sindex, nzCounts=0;
276 int *samemask, *diffmask, *sir, *sjc, *bCumprod, *sCumprod, *ssubv, *sequence, *weight;
277 double *bigTable, *pbDomain, *psDomain, *pbSize, *psSize, *spr;
278 mxArray *pTemp, *pTemp1;
279
280 pTemp = mxGetField(bigPot, 0, "domain");
281 pbDomain = mxGetPr(pTemp);
282 bdim = mxGetNumberOfElements(pTemp);
283 pTemp = mxGetField(smallPot, 0, "domain");
284 psDomain = mxGetPr(pTemp);
285 sdim = mxGetNumberOfElements(pTemp);
286
287 pTemp = mxGetField(bigPot, 0, "sizes");
288 pbSize = mxGetPr(pTemp);
289 pTemp = mxGetField(smallPot, 0, "sizes");
290 psSize = mxGetPr(pTemp);
291
292 NB = 1;
293 for(i=0; i<bdim; i++){
294 NB *= (int)pbSize[i];
295 }
296 NS = 1;
297 for(i=0; i<sdim; i++){
298 NS *= (int)psSize[i];
299 }
300 ND = NB / NS;
301
302 if(ND == 1){
303 pTemp = mxGetField(bigPot, 0, "T");
304 if(pTemp)mxDestroyArray(pTemp);
305 pTemp1 = mxGetField(smallPot, 0, "T");
306 pTemp = mxDuplicateArray(pTemp1);
307 mxSetField(bigPot, 0, "T", pTemp);
308 return;
309 }
310
311 pTemp = mxGetField(smallPot, 0, "T");
312 spr = mxGetPr(pTemp);
313 sir = mxGetIr(pTemp);
314 sjc = mxGetJc(pTemp);
315 NZS = sjc[1];
316
317 NZB = ND * NZS;
318
319 diffdim = bdim - sdim;
320 sequence = malloc(NZB * 2 * sizeof(int));
321 bigTable = malloc(NZB * sizeof(double));
322 samemask = malloc(sdim * sizeof(int));
323 diffmask = malloc(diffdim * sizeof(int));
324 bCumprod = malloc(bdim * sizeof(int));
325 sCumprod = malloc(sdim * sizeof(int));
326 weight = malloc(ND * sizeof(int));
327 ssubv = malloc(sdim * sizeof(int));
328
329 count = 0;
330 count1 = 0;
331 for(i=0; i<bdim; i++){
332 match = 0;
333 for(j=0; j<sdim; j++){
334 if(pbDomain[i] == psDomain[j]){
335 samemask[count] = i;
336 match = 1;
337 count++;
338 break;
339 }
340 }
341 if(match == 0){
342 diffmask[count1] = i;
343 count1++;
344 }
345 }
346
347 bCumprod[0] = 1;
348 for(i=0; i<bdim-1; i++){
349 bCumprod[i+1] = bCumprod[i] * (int)pbSize[i];
350 }
351 sCumprod[0] = 1;
352 for(i=0; i<sdim-1; i++){
353 sCumprod[i+1] = sCumprod[i] * (int)psSize[i];
354 }
355
356 count = 0;
357 compute_fixed_weight(weight, pbSize, diffmask, bCumprod, ND, diffdim);
358 for(i=0; i<NZS; i++){
359 sindex = sir[i];
360 ind_subv(sindex, sCumprod, sdim, ssubv);
361 temp = 0;
362 for(j=0; j<sdim; j++){
363 temp += ssubv[j] * bCumprod[samemask[j]];
364 }
365 for(j=0; j<ND; j++){
366 bindex = weight[j] + temp;
367 bigTable[nzCounts] = spr[i];
368 sequence[count] = bindex;
369 count++;
370 sequence[count] = nzCounts;
371 nzCounts++;
372 count++;
373 }
374 }
375
376 pTemp = mxGetField(bigPot, 0, "T");
377 if(pTemp)mxDestroyArray(pTemp);
378 qsort(sequence, nzCounts, sizeof(int) * 2, compare);
379 pTemp = convert_ill_table_to_sparse(bigTable, sequence, nzCounts, NB);
380 mxSetField(bigPot, 0, "T", pTemp);
381
382 free(sequence);
383 free(bigTable);
384 free(samemask);
385 free(diffmask);
386 free(bCumprod);
387 free(sCumprod);
388 free(weight);
389 free(ssubv);
390 }
391
392 void multiply_spPot_by_fuPot(mxArray *bigPot, const mxArray *smallPot){
393 int i, j, count, bdim, sdim, NB, NZB, bindex, sindex, nzCounts=0;
394 int *mask, *index, *bir, *bjc, *bCumprod, *sCumprod, *bsubv, *ssubv;
395 double *bigTable, *pbDomain, *psDomain, *pbSize, *psSize, *bpr, *spr, value;
396 mxArray *pTemp;
397
398 pTemp = mxGetField(bigPot, 0, "domain");
399 pbDomain = mxGetPr(pTemp);
400 bdim = mxGetNumberOfElements(pTemp);
401 pTemp = mxGetField(smallPot, 0, "domain");
402 psDomain = mxGetPr(pTemp);
403 sdim = mxGetNumberOfElements(pTemp);
404
405 pTemp = mxGetField(bigPot, 0, "sizes");
406 pbSize = mxGetPr(pTemp);
407 pTemp = mxGetField(smallPot, 0, "sizes");
408 psSize = mxGetPr(pTemp);
409
410 NB = 1;
411 for(i=0; i<bdim; i++){
412 NB *= (int)pbSize[i];
413 }
414
415 pTemp = mxGetField(bigPot, 0, "T");
416 bpr = mxGetPr(pTemp);
417 bir = mxGetIr(pTemp);
418 bjc = mxGetJc(pTemp);
419 NZB = bjc[1];
420
421 pTemp = mxGetField(smallPot, 0, "T");
422 spr = mxGetPr(pTemp);
423
424 bigTable = malloc(NZB * sizeof(double));
425 index = malloc(NZB * sizeof(double));
426 mask = malloc(sdim * sizeof(int));
427 bCumprod = malloc(bdim * sizeof(int));
428 sCumprod = malloc(sdim * sizeof(int));
429 bsubv = malloc(bdim * sizeof(int));
430 ssubv = malloc(sdim * sizeof(int));
431
432 for(i=0; i<NZB; i++){
433 bigTable[i] = 0;
434 }
435 count = 0;
436 for(i=0; i<sdim; i++){
437 for(j=0; j<bdim; j++){
438 if(psDomain[i] == pbDomain[j]){
439 mask[count] = j;
440 count++;
441 break;
442 }
443 }
444 }
445
446 bCumprod[0] = 1;
447 for(i=0; i<bdim-1; i++){
448 bCumprod[i+1] = bCumprod[i] * (int)pbSize[i];
449 }
450 sCumprod[0] = 1;
451 for(i=0; i<sdim-1; i++){
452 sCumprod[i+1] = sCumprod[i] * (int)psSize[i];
453 }
454
455 for(i=0; i<NZB; i++){
456 bindex = bir[i];
457 ind_subv(bindex, bCumprod, bdim, bsubv);
458 for(j=0; j<sdim; j++){
459 ssubv[j] = bsubv[mask[j]];
460 }
461 sindex = subv_ind(sdim, sCumprod, ssubv);
462 value = spr[sindex];
463 if(value != 0){
464 bigTable[nzCounts] = bpr[i] * value;
465 index[nzCounts] = bindex;
466 nzCounts++;
467 }
468 }
469
470 pTemp = mxGetField(bigPot, 0, "T");
471 if(pTemp)mxDestroyArray(pTemp);
472 pTemp = convert_table_to_sparse(bigTable, index, nzCounts, NB);
473 mxSetField(bigPot, 0, "T", pTemp);
474
475 free(bigTable);
476 free(index);
477 free(mask);
478 free(bCumprod);
479 free(sCumprod);
480 free(bsubv);
481 free(ssubv);
482 }
483
484 void multiply_spPot_by_spPot(mxArray *bigPot, const mxArray *smallPot){
485 int i, j, count, bdim, sdim, NB, NZB, NZS, position, bindex, sindex, nzCounts=0;
486 int *mask, *index, *result, *bir, *sir, *bjc, *sjc, *bCumprod, *sCumprod, *bsubv, *ssubv;
487 double *bigTable, *pbDomain, *psDomain, *pbSize, *psSize, *bpr, *spr, value;
488 mxArray *pTemp;
489
490 pTemp = mxGetField(bigPot, 0, "domain");
491 pbDomain = mxGetPr(pTemp);
492 bdim = mxGetNumberOfElements(pTemp);
493 pTemp = mxGetField(smallPot, 0, "domain");
494 psDomain = mxGetPr(pTemp);
495 sdim = mxGetNumberOfElements(pTemp);
496
497 pTemp = mxGetField(bigPot, 0, "sizes");
498 pbSize = mxGetPr(pTemp);
499 pTemp = mxGetField(smallPot, 0, "sizes");
500 psSize = mxGetPr(pTemp);
501
502 NB = 1;
503 for(i=0; i<bdim; i++){
504 NB *= (int)pbSize[i];
505 }
506
507 pTemp = mxGetField(bigPot, 0, "T");
508 bpr = mxGetPr(pTemp);
509 bir = mxGetIr(pTemp);
510 bjc = mxGetJc(pTemp);
511 NZB = bjc[1];
512
513 pTemp = mxGetField(smallPot, 0, "T");
514 spr = mxGetPr(pTemp);
515 sir = mxGetIr(pTemp);
516 sjc = mxGetJc(pTemp);
517 NZS = sjc[1];
518
519 bigTable = malloc(NZB * sizeof(double));
520 index = malloc(NZB * sizeof(double));
521 mask = malloc(sdim * sizeof(int));
522 bCumprod = malloc(bdim * sizeof(int));
523 sCumprod = malloc(sdim * sizeof(int));
524 bsubv = malloc(bdim * sizeof(int));
525 ssubv = malloc(sdim * sizeof(int));
526
527 for(i=0; i<NZB; i++){
528 bigTable[i] = 0;
529 }
530 count = 0;
531 for(i=0; i<sdim; i++){
532 for(j=0; j<bdim; j++){
533 if(psDomain[i] == pbDomain[j]){
534 mask[count] = j;
535 count++;
536 break;
537 }
538 }
539 }
540
541 bCumprod[0] = 1;
542 for(i=0; i<bdim-1; i++){
543 bCumprod[i+1] = bCumprod[i] * (int)pbSize[i];
544 }
545 sCumprod[0] = 1;
546 for(i=0; i<sdim-1; i++){
547 sCumprod[i+1] = sCumprod[i] * (int)psSize[i];
548 }
549
550 for(i=0; i<NZB; i++){
551 value = bpr[i];
552 bindex = bir[i];
553 ind_subv(bindex, bCumprod, bdim, bsubv);
554 for(j=0; j<sdim; j++){
555 ssubv[j] = bsubv[mask[j]];
556 }
557 sindex = subv_ind(sdim, sCumprod, ssubv);
558 result = (int *) bsearch(&sindex, sir, NZS, sizeof(int), compare);
559 if(result){
560 position = result - sir;
561 value *= spr[position];
562 bigTable[nzCounts] = value;
563 index[nzCounts] = bindex;
564 nzCounts++;
565 }
566 }
567
568 pTemp = mxGetField(bigPot, 0, "T");
569 if(pTemp)mxDestroyArray(pTemp);
570 pTemp = convert_table_to_sparse(bigTable, index, nzCounts, NB);
571 mxSetField(bigPot, 0, "T", pTemp);
572
573 free(bigTable);
574 free(index);
575 free(mask);
576 free(bCumprod);
577 free(sCumprod);
578 free(bsubv);
579 free(ssubv);
580 }
581
582
583 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){
584 int i, j, c, loop, nNodes, nCliques, ndomain, dims[2];
585 double *pClqs, *pr, *pt, *pSize;
586 mxArray *pTemp, *pTemp1, *pStruct, *pCliques, *pBigpot, *pSmallpot;
587 const char *field_names[] = {"domain", "T", "sizes"};
588
589 nNodes = mxGetNumberOfElements(prhs[1]);
590 pCliques = mxGetField(prhs[0], 0, "cliques");
591 nCliques = mxGetNumberOfElements(pCliques);
592 pTemp = mxGetField(prhs[0], 0, "eff_node_sizes");
593 pSize = mxGetPr(pTemp);
594
595 plhs[0] = mxCreateCellArray(1, &nCliques);
596 for(i=0; i<nCliques; i++){
597 pStruct = mxCreateStructMatrix(1, 1, 3, field_names);
598 mxSetCell(plhs[0], i, pStruct);
599 pTemp = mxGetCell(pCliques, i);
600 ndomain = mxGetNumberOfElements(pTemp);
601 pt = mxGetPr(pTemp);
602 pTemp1 = mxDuplicateArray(pTemp);
603 mxSetField(pStruct, 0, "domain", pTemp1);
604
605 pTemp = mxCreateDoubleMatrix(1, ndomain, mxREAL);
606 mxSetField(pStruct, 0, "sizes", pTemp);
607 pr = mxGetPr(pTemp);
608 for(j=0; j<ndomain; j++){
609 pr[j] = pSize[(int)pt[j]-1];
610 }
611 }
612
613 pClqs = mxGetPr(prhs[1]);
614 for(loop=0; loop<nNodes; loop++){
615 c = (int)pClqs[loop] - 1;
616 pSmallpot = mxGetCell(prhs[2], loop);
617 pTemp = mxGetField(pSmallpot, 0, "T");
618 pBigpot = mxGetCell(plhs[0], c);
619 pTemp1 = mxGetField(pBigpot, 0, "T");
620 if(pTemp1){
621 if(mxIsSparse(pTemp))
622 multiply_spPot_by_spPot(pBigpot, pSmallpot);
623 else multiply_spPot_by_fuPot(pBigpot, pSmallpot);
624 }
625 else{
626 if(mxIsSparse(pTemp))
627 multiply_null_by_spPot(pBigpot, pSmallpot);
628 else multiply_null_by_fuPot(pBigpot, pSmallpot);
629 }
630 }
631
632 dims[0] = nCliques;
633 dims[1] = nCliques;
634 plhs[1] = mxCreateCellArray(2, dims);
635 }
636
637