wolffd@0: #include "mex.h" wolffd@0: wolffd@0: void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray wolffd@0: *prhs[]) wolffd@0: { wolffd@0: double *pn, *pi, *pj, *pm, *y, *ecElts, *pcpt, *famElts, *strideElts, wolffd@0: *ev, *nsElts; wolffd@0: int i, k, j, m, n; wolffd@0: mxArray *ec, *cpt, *fam, *ns; wolffd@0: int c1, famSize, nsj; wolffd@0: int strideStride, startInd, stride, pos, numNodes; wolffd@0: wolffd@0: const int BNET = 0; wolffd@0: const int STATE = 1; wolffd@0: const int STRIDES = 6; wolffd@0: const int FAMILIES = 7; wolffd@0: const int CPT = 8; wolffd@0: wolffd@0: pn = mxGetPr(prhs[3]); wolffd@0: n = (int) pn[0]; wolffd@0: pi = mxGetPr(prhs[2]); wolffd@0: i = (int) pi[0]; wolffd@0: pj = mxGetPr(prhs[4]); wolffd@0: j = (int) pj[0]; wolffd@0: pm = mxGetPr(prhs[5]); wolffd@0: m = (int) pm[0]; wolffd@0: ev = mxGetPr(prhs[STATE]); wolffd@0: ns = mxGetField (prhs[BNET], 0, "node_sizes"); wolffd@0: nsElts = mxGetPr (ns); wolffd@0: numNodes = mxGetM(ns); wolffd@0: wolffd@0: strideStride = mxGetM(prhs[STRIDES]); wolffd@0: strideElts = mxGetPr(prhs[STRIDES]); wolffd@0: wolffd@0: wolffd@0: wolffd@0: /* Treat the case n = 1 separately */ wolffd@0: if (pn[0] == 1) { wolffd@0: wolffd@0: /* Get the appropriate CPT */ wolffd@0: ec = mxGetField (prhs[BNET], 0, "eclass1"); wolffd@0: ecElts = mxGetPr(ec); wolffd@0: k = (int) ecElts[i-1]; wolffd@0: cpt = mxGetCell (prhs[8], k-1); wolffd@0: pcpt = mxGetPr(cpt); wolffd@0: wolffd@0: nsj = (int) nsElts[j-1]; wolffd@0: wolffd@0: /* Get the correct family vector */ wolffd@0: /* (Note : MEX is painful) */ wolffd@0: fam = mxGetCell (prhs[FAMILIES], i - 1); wolffd@0: famSize = mxGetNumberOfElements(fam); wolffd@0: famElts = mxGetPr(fam); wolffd@0: wolffd@0: wolffd@0: /* Figure out starting position and stride */ wolffd@0: startInd = 0; wolffd@0: for (c1 = 0, pos = k-1; c1 < famSize; c1++, pos+=strideStride) { wolffd@0: if (famElts[c1] != j) { wolffd@0: startInd += strideElts[pos]*(ev[(int)famElts[c1]-1]-1); wolffd@0: } wolffd@0: else { wolffd@0: stride = strideElts[pos]; wolffd@0: } wolffd@0: } wolffd@0: wolffd@0: plhs[0] = mxCreateDoubleMatrix (1, nsj, mxREAL); wolffd@0: y = mxGetPr(plhs[0]); wolffd@0: for (c1 = 0, pos = startInd; c1 < nsj; c1++, pos+=stride) { wolffd@0: y[c1] = pcpt[pos]; wolffd@0: } wolffd@0: } wolffd@0: wolffd@0: /* Handle the case n > 1 */ wolffd@0: else { wolffd@0: wolffd@0: /* Get the appropriate CPT */ wolffd@0: ec = mxGetField (prhs[BNET], 0, "eclass2"); wolffd@0: ecElts = mxGetPr(ec); wolffd@0: k = (int) ecElts[i-1]; wolffd@0: cpt = mxGetCell (prhs[8], k-1); wolffd@0: pcpt = mxGetPr(cpt); wolffd@0: wolffd@0: /* Figure out size of slice */ wolffd@0: if (m == 1) { wolffd@0: nsj = (int) nsElts[j-1]; wolffd@0: } wolffd@0: else { wolffd@0: nsj = (int) nsElts[j-1+numNodes]; wolffd@0: } wolffd@0: wolffd@0: /* Figure out family */ wolffd@0: fam = mxGetCell (prhs[FAMILIES], i - 1 + numNodes); wolffd@0: famSize = mxGetNumberOfElements(fam); wolffd@0: famElts = mxGetPr(fam); wolffd@0: wolffd@0: startInd = 0; wolffd@0: for (c1 = 0, pos = k-1; c1 < famSize; c1++, pos+=strideStride) { wolffd@0: int f = (int) famElts[c1]; wolffd@0: wolffd@0: if (((f == j+numNodes) && (m == n)) || ((f == j) && (m == wolffd@0: n-1))) { wolffd@0: stride = strideElts[pos]; wolffd@0: } wolffd@0: else { wolffd@0: startInd += strideElts[pos] * (ev[f-1+((n-2)*numNodes)]-1); wolffd@0: } wolffd@0: } wolffd@0: wolffd@0: plhs[0] = mxCreateDoubleMatrix(1,nsj, mxREAL); wolffd@0: y = mxGetPr(plhs[0]); wolffd@0: for (c1 = 0, pos = startInd; c1 < nsj; c1++, pos+=stride) { wolffd@0: y[c1] = pcpt[pos]; wolffd@0: } wolffd@0: } wolffd@0: }