diff toolboxes/FullBNT-1.0.7/bnt/general/solve_limid.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/toolboxes/FullBNT-1.0.7/bnt/general/solve_limid.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,68 @@
+function [strategy, MEU, niter] = solve_limid(engine, varargin)
+% SOLVE_LIMID Find the (locally) optimal strategy for a LIMID
+% [strategy, MEU, niter] = solve_limid(inf_engine, ...)
+%
+% strategy{d} = stochastic policy for node d (a decision node)
+% MEU = maximum expected utility
+% niter = num iterations used
+%
+% The following optional arguments can be specified in the form of name/value pairs:
+% [default in brackets]
+%
+% max_iter - max. num. iterations [ 1 ]
+% tol - tolerance required of consecutive MEU values, used to assess convergence [1e-3]
+% order - order in which decision nodes are optimized [ reverse numerical order ]
+%
+% e.g., solve_limid(engine, 'tol', 1e-2, 'max_iter', 10)
+
+bnet = bnet_from_engine(engine);
+
+% default values
+max_iter = 1;
+tol = 1e-3;
+D = bnet.decision_nodes;
+order = D(end:-1:1);
+
+args = varargin;
+nargs = length(args);
+for i=1:2:nargs
+  switch args{i},
+   case 'max_iter', max_iter  = args{i+1}; 
+   case 'tol',      tol = args{i+1}; 
+   case 'order',    order = args{i+1}; 
+   otherwise,  
+    error(['invalid argument name ' args{i}]);       
+  end
+end
+
+CPDs = bnet.CPD;
+ns = bnet.node_sizes;
+N = length(ns);
+evidence = cell(1,N);
+strategy = cell(1, N);
+
+iter = 1;
+converged = 0;
+oldMEU = 0;
+while ~converged & (iter <= max_iter)
+  for d=order(:)'
+    engine = enter_evidence(engine, evidence, 'exclude', d);
+    [m, pot] = marginal_family(engine, d);
+    %pot = marginal_family_pot(engine, d);
+    [policy, score] = upot_to_opt_policy(pot);    
+    e = bnet.equiv_class(d);
+    CPDs{e} = set_fields(CPDs{e}, 'policy', policy);
+    engine = update_engine(engine, CPDs);
+    strategy{d} = policy;
+  end  
+  engine = enter_evidence(engine, evidence);
+  [m, pot] = marginal_nodes(engine, []);
+  %pot = marginal_family_pot(engine, []);
+  [dummy, MEU] = upot_to_opt_policy(pot);    
+  if approxeq(MEU, oldMEU, tol)
+    converged = 1;
+  end
+  oldMEU = MEU;
+  iter = iter + 1;
+end
+niter = iter - 1;