Mercurial > hg > camir-aes2014
diff toolboxes/FullBNT-1.0.7/bnt/inference/static/@gibbs_sampling_inf_engine/private/get_slice_dbn.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/toolboxes/FullBNT-1.0.7/bnt/inference/static/@gibbs_sampling_inf_engine/private/get_slice_dbn.m Tue Feb 10 15:05:51 2015 +0000 @@ -0,0 +1,87 @@ +function slice = get_slice_dbn(bnet, state, i, n, j, m, strides, families, ... + CPT) +% slice = get_slice(bnet, state, i, n, j, m, strides, families, cpt) +% +% GET_SLICE get one-dimensional slice of the CPT for node X_i^n +% that corresponds to the different values of X_j^m, where all +% other nodes have values given by state. +% strides is the result of +% calling compute_strides(bnet) +% families is the result of calling compute_families(bnet) +% cpts is the result of calling get_cpts(bnet) +% +% slice is a 1-d array + + +if (n == 1) + + k = bnet.eclass1(i); + c = CPT{k}; + + % Figure out evidence on family + fam = families{i, 1}; + ev = state(fam, 1); + + % Remove evidence on node j + pos = find(fam == j); + ev(pos) = 1; + dim = size(ev, 1); + + % Compute initial index and stride + start_ind = 1+strides(k, 1:dim)*(ev-1); + stride = strides(k, pos); + + % Compute the slice + slice = c(start_ind:stride:start_ind+(bnet.node_sizes(j, 1)-1)*stride); + +else + + k = bnet.eclass2(i); + c = CPT{k}; + + fam = families{i, 2}; + ss = length(bnet.intra); + + % Divide the family into nodes in this time step and nodes in the + % previous time step + this_time_step = fam(find(fam > ss)); + prev_time_step = fam(find(fam <= ss)); + + % Normalize the node numbers + this_time_step = this_time_step - ss; + + % Get the evidence + this_step_ev = state(this_time_step, n); + prev_step_ev = state(prev_time_step, n-1); + + % Remove the evidence for X_j^m + if (m == n) + pos = find(this_time_step == j); + this_step_ev(pos) = 1; + pos = pos + size(prev_time_step, 2); + else + assert (m == n-1); + pos = find(prev_time_step == j); + prev_step_ev(pos) = 1; + end + + % Combine the two time steps + ev = [prev_step_ev; this_step_ev]; + dim = size(ev, 1); + + + % Compute starting index and stride + start_ind = 1 + strides(k, 1:dim)*(ev-1); + stride = strides(k, pos); + + % Compute slice + if (m == 1) + q = 1; + else + q = 2; + end + slice = c(start_ind:stride:start_ind+(bnet.node_sizes(j, q)-1)*stride); +end + + +