comparison toolboxes/FullBNT-1.0.7/bnt/inference/static/@jtree_sparse_inf_engine/init_pot.c @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:e9a9cd732c1e
1 /* C mex init_pot for in @jtree_sparse_inf_engine directory */
2 /* The file enter_evidence.m in directory @jtree_sparse_inf_engine call it*/
3
4 /**************************************/
5 /* init_pot.c has 5 input & 2 output */
6 /* engine */
7 /* clqs */
8 /* pots */
9 /* pot_type */
10 /* onodes */
11 /* */
12 /* clpot */
13 /* seppot */
14 /**************************************/
15 #include <math.h>
16 #include <stdlib.h>
17 #include "mex.h"
18
19 int compare(const void* src1, const void* src2){
20 int i1 = *(int*)src1 ;
21 int i2 = *(int*)src2 ;
22 return i1-i2 ;
23 }
24
25 void ind_subv(int index, const int *cumprod, int n, int *bsubv){
26 int i;
27
28 for (i = n-1; i >= 0; i--) {
29 bsubv[i] = ((int)floor(index / cumprod[i]));
30 index = index % cumprod[i];
31 }
32 }
33
34 int subv_ind(const int n, const int *cumprod, const int *subv){
35 int i, index=0;
36
37 for(i=0; i<n; i++){
38 index += subv[i] * cumprod[i];
39 }
40 return index;
41 }
42
43 void compute_fixed_weight(int *weight, const double *pbSize, const int *dmask, const int *bCumprod, const int ND, const int diffdim){
44 int i, j;
45 int *eff_cumprod, *subv, *diffsize, *diff_cumprod;
46
47 subv = malloc(diffdim * sizeof(int));
48 eff_cumprod = malloc(diffdim * sizeof(int));
49 diffsize = malloc(diffdim * sizeof(int));
50 diff_cumprod = malloc(diffdim * sizeof(int));
51 for(i=0; i<diffdim; i++){
52 eff_cumprod[i] = bCumprod[dmask[i]];
53 diffsize[i] = (int)pbSize[dmask[i]];
54 }
55 diff_cumprod[0] = 1;
56 for(i=0; i<diffdim-1; i++){
57 diff_cumprod[i+1] = diff_cumprod[i] * diffsize[i];
58 }
59 for(i=0; i<ND; i++){
60 ind_subv(i, diff_cumprod, diffdim, subv);
61 weight[i] = 0;
62 for(j=0; j<diffdim; j++){
63 weight[i] += eff_cumprod[j] * subv[j];
64 }
65 }
66 free(eff_cumprod);
67 free(subv);
68 free(diffsize);
69 free(diff_cumprod);
70 }
71
72 void reset_nzmax(mxArray *spArray, const int old_nzmax, const int new_nzmax){
73 double *ptr;
74 void *newptr;
75 int *ir, *jc;
76 int nbytes;
77
78 if(new_nzmax == old_nzmax) return;
79 nbytes = new_nzmax * sizeof(*ptr);
80 ptr = mxGetPr(spArray);
81 newptr = mxRealloc(ptr, nbytes);
82 mxSetPr(spArray, newptr);
83 nbytes = new_nzmax * sizeof(*ir);
84 ir = mxGetIr(spArray);
85 newptr = mxRealloc(ir, nbytes);
86 mxSetIr(spArray, newptr);
87 jc = mxGetJc(spArray);
88 jc[0] = 0;
89 jc[1] = new_nzmax;
90 mxSetNzmax(spArray, new_nzmax);
91 }
92
93 mxArray* convert_ill_table_to_sparse(const double *bigTable, const int *sequence, const int nzCounts, const int NB){
94 mxArray *spTable;
95 int i, temp, *irs, *jcs, count=0;
96 double *sr;
97
98 spTable = mxCreateSparse(NB, 1, nzCounts, mxREAL);
99 sr = mxGetPr(spTable);
100 irs = mxGetIr(spTable);
101 jcs = mxGetJc(spTable);
102
103 jcs[0] = 0;
104 jcs[1] = nzCounts;
105
106 for(i=0; i<nzCounts; i++){
107 irs[i] = sequence[count];
108 count++;
109 temp = sequence[count];
110 sr[i] = bigTable[temp];
111 count++;
112 }
113 return spTable;
114 }
115
116 void multiply_null_by_fuPot(mxArray *bigPot, const mxArray *smallPot){
117 int i, j, count, NB, NS, siz_b, siz_s, ndim, nzCounts=0;
118 int *mask, *sx, *sy, *cpsy, *subs, *s, *cpsy2, *bir, *bjc;
119 double *pbDomain, *psDomain, *pbSize, *psSize, *spr, *bpr, value;
120 mxArray *pTemp, *pTemp1;
121
122 pTemp = mxGetField(bigPot, 0, "domain");
123 pbDomain = mxGetPr(pTemp);
124 siz_b = mxGetNumberOfElements(pTemp);
125 pTemp = mxGetField(smallPot, 0, "domain");
126 psDomain = mxGetPr(pTemp);
127 siz_s = mxGetNumberOfElements(pTemp);
128
129 pTemp = mxGetField(bigPot, 0, "sizes");
130 pbSize = mxGetPr(pTemp);
131 pTemp = mxGetField(smallPot, 0, "sizes");
132 psSize = mxGetPr(pTemp);
133
134 NB = 1;
135 for(i=0; i<siz_b; i++){
136 NB *= (int)pbSize[i];
137 }
138 NS = 1;
139 for(i=0; i<siz_s; i++){
140 NS *= (int)psSize[i];
141 }
142
143 pTemp = mxGetField(smallPot, 0, "T");
144 spr = mxGetPr(pTemp);
145
146 pTemp1 = mxCreateSparse(NB, 1, NB, mxREAL);
147 bpr = mxGetPr(pTemp1);
148 bir = mxGetIr(pTemp1);
149 bjc = mxGetJc(pTemp1);
150 bjc[0] = 0;
151 bjc[1] = NB;
152
153 if(NS == 1){
154 value = *spr;
155 for(i=0; i<NB; i++){
156 bpr[i] = value;
157 bir[i] = i;
158 }
159 nzCounts = NB;
160 pTemp = mxGetField(bigPot, 0, "T");
161 if(pTemp)mxDestroyArray(pTemp);
162 reset_nzmax(pTemp1, NB, nzCounts);
163 mxSetField(bigPot, 0, "T", pTemp1);
164 return;
165 }
166
167 if(NS == NB){
168 for(i=0; i<NB; i++){
169 if(spr[i] != 0){
170 bpr[nzCounts] = spr[i];
171 bir[nzCounts] = i;
172 nzCounts++;
173 }
174 }
175 pTemp = mxGetField(bigPot, 0, "T");
176 if(pTemp)mxDestroyArray(pTemp);
177 reset_nzmax(pTemp1, NB, nzCounts);
178 mxSetField(bigPot, 0, "T", pTemp1);
179 return;
180 }
181
182 mask = malloc(siz_s * sizeof(int));
183 count = 0;
184 for(i=0; i<siz_s; i++){
185 for(j=0; j<siz_b; j++){
186 if(psDomain[i] == pbDomain[j]){
187 mask[count] = j;
188 count++;
189 break;
190 }
191 }
192 }
193
194 ndim = siz_b;
195 sx = (int *)malloc(sizeof(int)*ndim);
196 sy = (int *)malloc(sizeof(int)*ndim);
197 for(i=0; i<ndim; i++){
198 sx[i] = (int)pbSize[i];
199 sy[i] = 1;
200 }
201 for(i=0; i<count; i++){
202 sy[mask[i]] = sx[mask[i]];
203 }
204
205 s = (int *)malloc(sizeof(int)*ndim);
206 *(cpsy = (int *)malloc(sizeof(int)*ndim)) = 1;
207 subs = (int *)malloc(sizeof(int)*ndim);
208 cpsy2 = (int *)malloc(sizeof(int)*ndim);
209 for(i = 0; i < ndim; i++){
210 subs[i] = 0;
211 s[i] = sx[i] - 1;
212 }
213
214 for(i = 0; i < ndim-1; i++){
215 cpsy[i+1] = cpsy[i]*sy[i]--;
216 cpsy2[i] = cpsy[i]*sy[i];
217 }
218 cpsy2[ndim-1] = cpsy[ndim-1]*(--sy[ndim-1]);
219
220 for(j=0; j<NB; j++){
221 if(*spr != 0){
222 bpr[nzCounts] = *spr;
223 bir[nzCounts] = j;
224 nzCounts++;
225 }
226 for(i = 0; i < ndim; i++){
227 if(subs[i] == s[i]){
228 subs[i] = 0;
229 if(sy[i])
230 spr -= cpsy2[i];
231 }
232 else{
233 subs[i]++;
234 if(sy[i])
235 spr += cpsy[i];
236 break;
237 }
238 }
239 }
240
241 pTemp = mxGetField(bigPot, 0, "T");
242 if(pTemp)mxDestroyArray(pTemp);
243 reset_nzmax(pTemp1, NB, nzCounts);
244 mxSetField(bigPot, 0, "T", pTemp1);
245
246 free(sx);
247 free(sy);
248 free(s);
249 free(cpsy);
250 free(subs);
251 free(cpsy2);
252 free(mask);
253 }
254
255 void multiply_null_by_spPot(mxArray *bigPot, const mxArray *smallPot){
256 int i, j, count, count1, match, temp, bdim, sdim, diffdim, NB, NS, ND, NZB, NZS, bindex, sindex, nzCounts=0;
257 int *samemask, *diffmask, *sir, *sjc, *bCumprod, *sCumprod, *ssubv, *sequence, *weight;
258 double *bigTable, *pbDomain, *psDomain, *pbSize, *psSize, *spr;
259 mxArray *pTemp, *pTemp1;
260
261 pTemp = mxGetField(bigPot, 0, "domain");
262 pbDomain = mxGetPr(pTemp);
263 bdim = mxGetNumberOfElements(pTemp);
264 pTemp = mxGetField(smallPot, 0, "domain");
265 psDomain = mxGetPr(pTemp);
266 sdim = mxGetNumberOfElements(pTemp);
267
268 pTemp = mxGetField(bigPot, 0, "sizes");
269 pbSize = mxGetPr(pTemp);
270 pTemp = mxGetField(smallPot, 0, "sizes");
271 psSize = mxGetPr(pTemp);
272
273 NB = 1;
274 for(i=0; i<bdim; i++){
275 NB *= (int)pbSize[i];
276 }
277 NS = 1;
278 for(i=0; i<sdim; i++){
279 NS *= (int)psSize[i];
280 }
281 ND = NB / NS;
282
283 if(ND == 1){
284 pTemp = mxGetField(bigPot, 0, "T");
285 if(pTemp)mxDestroyArray(pTemp);
286 pTemp1 = mxGetField(smallPot, 0, "T");
287 pTemp = mxDuplicateArray(pTemp1);
288 mxSetField(bigPot, 0, "T", pTemp);
289 return;
290 }
291
292 pTemp = mxGetField(smallPot, 0, "T");
293 spr = mxGetPr(pTemp);
294 sir = mxGetIr(pTemp);
295 sjc = mxGetJc(pTemp);
296 NZS = sjc[1];
297
298 NZB = ND * NZS;
299
300 diffdim = bdim - sdim;
301 sequence = malloc(NZB * 2 * sizeof(int));
302 bigTable = malloc(NZB * sizeof(double));
303 samemask = malloc(sdim * sizeof(int));
304 diffmask = malloc(diffdim * sizeof(int));
305 bCumprod = malloc(bdim * sizeof(int));
306 sCumprod = malloc(sdim * sizeof(int));
307 weight = malloc(ND * sizeof(int));
308 ssubv = malloc(sdim * sizeof(int));
309
310 count = 0;
311 count1 = 0;
312 for(i=0; i<bdim; i++){
313 match = 0;
314 for(j=0; j<sdim; j++){
315 if(pbDomain[i] == psDomain[j]){
316 samemask[count] = i;
317 match = 1;
318 count++;
319 break;
320 }
321 }
322 if(match == 0){
323 diffmask[count1] = i;
324 count1++;
325 }
326 }
327
328 bCumprod[0] = 1;
329 for(i=0; i<bdim-1; i++){
330 bCumprod[i+1] = bCumprod[i] * (int)pbSize[i];
331 }
332 sCumprod[0] = 1;
333 for(i=0; i<sdim-1; i++){
334 sCumprod[i+1] = sCumprod[i] * (int)psSize[i];
335 }
336
337 count = 0;
338 compute_fixed_weight(weight, pbSize, diffmask, bCumprod, ND, diffdim);
339 for(i=0; i<NZS; i++){
340 sindex = sir[i];
341 ind_subv(sindex, sCumprod, sdim, ssubv);
342 temp = 0;
343 for(j=0; j<sdim; j++){
344 temp += ssubv[j] * bCumprod[samemask[j]];
345 }
346 for(j=0; j<ND; j++){
347 bindex = weight[j] + temp;
348 bigTable[nzCounts] = spr[i];
349 sequence[count] = bindex;
350 count++;
351 sequence[count] = nzCounts;
352 nzCounts++;
353 count++;
354 }
355 }
356
357 pTemp = mxGetField(bigPot, 0, "T");
358 if(pTemp)mxDestroyArray(pTemp);
359 qsort(sequence, nzCounts, sizeof(int) * 2, compare);
360 pTemp = convert_ill_table_to_sparse(bigTable, sequence, nzCounts, NB);
361 mxSetField(bigPot, 0, "T", pTemp);
362
363 free(sequence);
364 free(bigTable);
365 free(samemask);
366 free(diffmask);
367 free(bCumprod);
368 free(sCumprod);
369 free(weight);
370 free(ssubv);
371 }
372
373 void multiply_spPot_by_fuPot(mxArray *bigPot, const mxArray *smallPot){
374 int i, j, count, bdim, sdim, NB, NZB, bindex, sindex, nzCounts=0;
375 int *mask, *bir, *bjc, *rir, *rjc, *bCumprod, *sCumprod, *bsubv, *ssubv;
376 double *pbDomain, *psDomain, *pbSize, *psSize, *bpr, *spr, *rpr, value;
377 mxArray *pTemp, *pTemp1;
378
379 pTemp = mxGetField(bigPot, 0, "domain");
380 pbDomain = mxGetPr(pTemp);
381 bdim = mxGetNumberOfElements(pTemp);
382 pTemp = mxGetField(smallPot, 0, "domain");
383 psDomain = mxGetPr(pTemp);
384 sdim = mxGetNumberOfElements(pTemp);
385
386 pTemp = mxGetField(bigPot, 0, "sizes");
387 pbSize = mxGetPr(pTemp);
388 pTemp = mxGetField(smallPot, 0, "sizes");
389 psSize = mxGetPr(pTemp);
390
391 NB = 1;
392 for(i=0; i<bdim; i++){
393 NB *= (int)pbSize[i];
394 }
395
396 pTemp = mxGetField(bigPot, 0, "T");
397 bpr = mxGetPr(pTemp);
398 bir = mxGetIr(pTemp);
399 bjc = mxGetJc(pTemp);
400 NZB = bjc[1];
401
402 pTemp = mxGetField(smallPot, 0, "T");
403 spr = mxGetPr(pTemp);
404
405 pTemp1 = mxCreateSparse(NB, 1, NZB, mxREAL);
406 rpr = mxGetPr(pTemp1);
407 rir = mxGetIr(pTemp1);
408 rjc = mxGetJc(pTemp1);
409 rjc[0] = 0;
410 rjc[1] = NZB;
411
412 mask = malloc(sdim * sizeof(int));
413 bCumprod = malloc(bdim * sizeof(int));
414 sCumprod = malloc(sdim * sizeof(int));
415 bsubv = malloc(bdim * sizeof(int));
416 ssubv = malloc(sdim * sizeof(int));
417
418 count = 0;
419 for(i=0; i<sdim; i++){
420 for(j=0; j<bdim; j++){
421 if(psDomain[i] == pbDomain[j]){
422 mask[count] = j;
423 count++;
424 break;
425 }
426 }
427 }
428
429 bCumprod[0] = 1;
430 for(i=0; i<bdim-1; i++){
431 bCumprod[i+1] = bCumprod[i] * (int)pbSize[i];
432 }
433 sCumprod[0] = 1;
434 for(i=0; i<sdim-1; i++){
435 sCumprod[i+1] = sCumprod[i] * (int)psSize[i];
436 }
437
438 for(i=0; i<NZB; i++){
439 bindex = bir[i];
440 ind_subv(bindex, bCumprod, bdim, bsubv);
441 for(j=0; j<sdim; j++){
442 ssubv[j] = bsubv[mask[j]];
443 }
444 sindex = subv_ind(sdim, sCumprod, ssubv);
445 value = spr[sindex];
446 if(value != 0){
447 rpr[nzCounts] = bpr[i] * value;
448 rir[nzCounts] = bindex;
449 nzCounts++;
450 }
451 }
452
453 pTemp = mxGetField(bigPot, 0, "T");
454 if(pTemp)mxDestroyArray(pTemp);
455 reset_nzmax(pTemp1, NZB, nzCounts);
456 mxSetField(bigPot, 0, "T", pTemp1);
457
458 free(mask);
459 free(bCumprod);
460 free(sCumprod);
461 free(bsubv);
462 free(ssubv);
463 }
464
465 void multiply_spPot_by_spPot(mxArray *bigPot, const mxArray *smallPot){
466 int i, j, count, bdim, sdim, NB, NZB, NZS, position, bindex, sindex, nzCounts=0;
467 int *mask, *result, *bir, *sir, *rir, *bjc, *sjc, *rjc, *bCumprod, *sCumprod, *bsubv, *ssubv;
468 double *pbDomain, *psDomain, *pbSize, *psSize, *bpr, *spr, *rpr;
469 mxArray *pTemp, *pTemp1;
470
471 pTemp = mxGetField(bigPot, 0, "domain");
472 pbDomain = mxGetPr(pTemp);
473 bdim = mxGetNumberOfElements(pTemp);
474 pTemp = mxGetField(smallPot, 0, "domain");
475 psDomain = mxGetPr(pTemp);
476 sdim = mxGetNumberOfElements(pTemp);
477
478 pTemp = mxGetField(bigPot, 0, "sizes");
479 pbSize = mxGetPr(pTemp);
480 pTemp = mxGetField(smallPot, 0, "sizes");
481 psSize = mxGetPr(pTemp);
482
483 NB = 1;
484 for(i=0; i<bdim; i++){
485 NB *= (int)pbSize[i];
486 }
487
488 pTemp = mxGetField(bigPot, 0, "T");
489 bpr = mxGetPr(pTemp);
490 bir = mxGetIr(pTemp);
491 bjc = mxGetJc(pTemp);
492 NZB = bjc[1];
493
494 pTemp = mxGetField(smallPot, 0, "T");
495 spr = mxGetPr(pTemp);
496 sir = mxGetIr(pTemp);
497 sjc = mxGetJc(pTemp);
498 NZS = sjc[1];
499
500 pTemp1 = mxCreateSparse(NB, 1, NZB, mxREAL);
501 rpr = mxGetPr(pTemp1);
502 rir = mxGetIr(pTemp1);
503 rjc = mxGetJc(pTemp1);
504 rjc[0] = 0;
505 rjc[1] = NZB;
506
507 mask = malloc(sdim * sizeof(int));
508 bCumprod = malloc(bdim * sizeof(int));
509 sCumprod = malloc(sdim * sizeof(int));
510 bsubv = malloc(bdim * sizeof(int));
511 ssubv = malloc(sdim * sizeof(int));
512
513 count = 0;
514 for(i=0; i<sdim; i++){
515 for(j=0; j<bdim; j++){
516 if(psDomain[i] == pbDomain[j]){
517 mask[count] = j;
518 count++;
519 break;
520 }
521 }
522 }
523
524 bCumprod[0] = 1;
525 for(i=0; i<bdim-1; i++){
526 bCumprod[i+1] = bCumprod[i] * (int)pbSize[i];
527 }
528 sCumprod[0] = 1;
529 for(i=0; i<sdim-1; i++){
530 sCumprod[i+1] = sCumprod[i] * (int)psSize[i];
531 }
532
533 for(i=0; i<NZB; i++){
534 bindex = bir[i];
535 ind_subv(bindex, bCumprod, bdim, bsubv);
536 for(j=0; j<sdim; j++){
537 ssubv[j] = bsubv[mask[j]];
538 }
539 sindex = subv_ind(sdim, sCumprod, ssubv);
540 result = (int *) bsearch(&sindex, sir, NZS, sizeof(int), compare);
541 if(result){
542 position = result - sir;
543 rpr[nzCounts] = bpr[i] * spr[position];
544 rir[nzCounts] = bindex;
545 nzCounts++;
546 }
547 }
548
549 pTemp = mxGetField(bigPot, 0, "T");
550 if(pTemp)mxDestroyArray(pTemp);
551 reset_nzmax(pTemp1, NZB, nzCounts);
552 mxSetField(bigPot, 0, "T", pTemp1);
553
554 free(mask);
555 free(bCumprod);
556 free(sCumprod);
557 free(bsubv);
558 free(ssubv);
559 }
560
561
562 void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){
563 int i, j, c, loop, nNodes, nCliques, ndomain, ns_num, nOnodes, dims[2];
564 double *pClqs, *pr, *pt, *pSize, *eff_ns;
565 mxArray *pTemp, *pTemp1, *pStruct, *pCliques, *pBigpot, *pSmallpot;
566 const char *field_names[] = {"domain", "T", "sizes"};
567
568 nNodes = mxGetNumberOfElements(prhs[1]);
569 pCliques = mxGetField(prhs[0], 0, "cliques");
570 nCliques = mxGetNumberOfElements(pCliques);
571 pTemp = mxGetField(prhs[0], 0, "actual_node_sizes");
572 ns_num = mxGetNumberOfElements(pTemp);
573 pSize = mxGetPr(pTemp);
574
575 eff_ns = (double *)malloc(ns_num * sizeof(double));
576 for(i=0; i<ns_num; i++) eff_ns[i] = pSize[i];
577 nOnodes = mxGetNumberOfElements(prhs[4]);
578 pr = mxGetPr(prhs[4]);
579 for(i=0; i<nOnodes; i++) eff_ns[(int)pr[i] - 1] = 1;
580
581 plhs[0] = mxCreateCellArray(1, &nCliques);
582 for(i=0; i<nCliques; i++){
583 pStruct = mxCreateStructMatrix(1, 1, 3, field_names);
584 mxSetCell(plhs[0], i, pStruct);
585 pTemp = mxGetCell(pCliques, i);
586 ndomain = mxGetNumberOfElements(pTemp);
587 pt = mxGetPr(pTemp);
588 pTemp1 = mxDuplicateArray(pTemp);
589 mxSetField(pStruct, 0, "domain", pTemp1);
590
591 pTemp = mxCreateDoubleMatrix(1, ndomain, mxREAL);
592 mxSetField(pStruct, 0, "sizes", pTemp);
593 pr = mxGetPr(pTemp);
594 for(j=0; j<ndomain; j++){
595 pr[j] = eff_ns[(int)pt[j]-1];
596 }
597 }
598
599 pClqs = mxGetPr(prhs[1]);
600 for(loop=0; loop<nNodes; loop++){
601 c = (int)pClqs[loop] - 1;
602 pSmallpot = mxGetCell(prhs[2], loop);
603 pTemp = mxGetField(pSmallpot, 0, "T");
604 pBigpot = mxGetCell(plhs[0], c);
605 pTemp1 = mxGetField(pBigpot, 0, "T");
606 if(pTemp1){
607 if(mxIsSparse(pTemp))
608 multiply_spPot_by_spPot(pBigpot, pSmallpot);
609 else multiply_spPot_by_fuPot(pBigpot, pSmallpot);
610 }
611 else{
612 if(mxIsSparse(pTemp))
613 multiply_null_by_spPot(pBigpot, pSmallpot);
614 else multiply_null_by_fuPot(pBigpot, pSmallpot);
615 }
616 }
617
618 free(eff_ns);
619 dims[0] = nCliques;
620 dims[1] = nCliques;
621 plhs[1] = mxCreateCellArray(2, dims);
622 }
623
624