annotate toolboxes/FullBNT-1.0.7/bnt/examples/dynamic/cmp_online_inference.m @ 0:cc4b1211e677 tip

initial commit to HG from Changeset: 646 (e263d8a21543) added further path and more save "camirversion.m"
author Daniel Wolff
date Fri, 19 Aug 2016 13:07:06 +0200
parents
children
rev   line source
function [time, engine] = cmp_online_inference(bnet, engine, T, varargin)
% CMP_ONLINE_INFERENCE Compare several online inference engines on a DBN
% function [time, engine] = cmp_online_inference(bnet, engine, T, ...)
%
% One sequence of length T is sampled from bnet with sample_dbn. The values
% of the observed nodes are then entered slice by slice (t = 1..T) into every
% engine via enter_evidence(engine{e}, evidence(:,t), t). After each slice
% the results of the "exact" engines are cross-checked against exact(1),
% and an error is raised on the first mismatch.
%
% engine{i} is the i'th inference engine.
% time(e) = elapsed time (summed tic/toc over the enter_evidence calls) for engine e
% engine  = the engines after all T slices of evidence have been entered
%
% The list below gives optional name/value arguments [default value in brackets].
%
% exact - specifies which engines do exact inference [ 1:length(engine) ]
%         exact(1) is the reference the other exact engines are compared with;
%         if exact is empty, no comparisons are made (only timing is collected)
% singletons_only - if 1, we only call marginal_nodes, else this and marginal_family [0]
% check_ll - 1 means we check that the log-likelihoods are correct [1]
% observed - indices (within one slice) of the nodes whose sampled values are
%            entered as evidence [ bnet.observed ]

% set default params
exact = 1:length(engine);
singletons_only = 0;
check_ll = 1;
onodes = bnet.observed;

args = varargin;
nargs = length(args);
if mod(nargs, 2) ~= 0
  % otherwise args{i+1} below fails with an opaque index error
  error('optional arguments must be given as name/value pairs');
end
for i=1:2:nargs
  switch args{i},
   case 'exact', exact = args{i+1};
   case 'singletons_only', singletons_only = args{i+1};
   case 'check_ll', check_ll = args{i+1};
   case 'observed', onodes = args{i+1};
   otherwise,
    error(['unrecognized argument ' args{i}])
  end
end

E = length(engine);
% With no exact engine there is no reference, so skip all cross-checks
% (exact(1) would otherwise fail on an empty vector).
do_compare = ~isempty(exact);
if do_compare
  ref = exact(1); % reference
  cmp = mysetdiff(exact, ref);
end

ss = length(bnet.intra);
hnodes = mysetdiff(1:ss, onodes);
ev = sample_dbn(bnet, 'length', T);
% Hidden nodes keep empty cells, i.e. only onodes are observed.
evidence = cell(ss,T);
evidence(onodes,:) = ev(onodes, :);

time = zeros(1,E);
ll = zeros(1,E); % log-likelihood returned by each engine for the current slice
for t=1:T
  for e=1:E
    tic;
    [engine{e}, ll(e)] = enter_evidence(engine{e}, evidence(:,t), t);
    time(e)= time(e) + toc;
  end
  if do_compare
    if check_ll
      % compared with approxeq's default tolerance
      for e=cmp(:)'
        if ~approxeq(ll(ref), ll(e))
          error(['engine ' num2str(e) ' has wrong ll'])
        end
      end
    end
    if ~singletons_only
      check_marginals(engine, hnodes, exact, 0, t);
    end
    check_marginals(engine, hnodes, exact, 1, t);
  end
end
Daniel@0 62
Daniel@0 63
Daniel@0 64 %%%%%%%%%%
Daniel@0 65
function check_marginals(engine, hnodes, exact, singletons, t)
% CHECK_MARGINALS Cross-check the slice-t marginals of the exact engines.
% check_marginals(engine, hnodes, exact, singletons, t)
%
% For every node n = 1..N of slice t (N = length(bnet.intra)), the marginal
% computed by each engine in exact(2:end) is compared with the one computed
% by the reference engine exact(1). An error is raised on the first mismatch.
%
% engine     - cell array of inference engines (evidence already entered)
% hnodes     - hidden nodes; currently unused, all nodes 1..N are checked
% exact      - indices into engine of the engines to compare; exact(1) is the reference
% singletons - if true, compare marginal_nodes(engine, n, t);
%              otherwise compare marginal_family(engine, n, t)
% t          - time slice whose marginals are checked
%
% For nodes listed in bnet.cnodes whose two marginals both have a 'mu' field,
% mu and Sigma are compared; otherwise the tables T(:) are compared
% (approxeq with its default tolerance).

if length(exact) < 2
  % Nothing to compare against the reference; this also avoids exact(1)
  % failing when exact is empty.
  return;
end

bnet = bnet_from_engine(engine{1});
N = length(bnet.intra);
cnodes_bitv = zeros(1,N);
cnodes_bitv(bnet.cnodes) = 1;
ref = exact(1); % reference
cmp = exact(2:end);
m = cell(1,length(engine));

for n=1:N
  %for n=hnodes(:)'
  % Only the exact engines' marginals are compared, so only compute those.
  for e=exact(:)'
    if singletons
      m{e} = marginal_nodes(engine{e}, n, t);
    else
      m{e} = marginal_family(engine{e}, n, t);
    end
  end
  for e=cmp(:)'
    assert(isequal(m{e}.domain, m{ref}.domain));
    if cnodes_bitv(n) && isfield(m{e}, 'mu') && isfield(m{ref}, 'mu')
      wrong = ~approxeq(m{ref}.mu, m{e}.mu) || ~approxeq(m{ref}.Sigma, m{e}.Sigma);
    else
      wrong = ~approxeq(m{ref}.T(:), m{e}.T(:));
    end
    if wrong
      error(sprintf('engine %d is wrong; n=%d, t=%d, fam=%d', e, n, t, ~singletons))
    end
  end
end