diff src/samer/models/MOGVector.java @ 0:bf79fb79ee13

Initial Mercurial check in.
author samer
date Tue, 17 Jan 2012 17:50:20 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/samer/models/MOGVector.java	Tue Jan 17 17:50:20 2012 +0000
@@ -0,0 +1,284 @@
+package samer.models;
+
+import	samer.core.*;
+import	samer.core.types.*;
+import	samer.maths.*;
+import	samer.tools.*;
+import	java.util.*;
+
+public class MOGVector extends NamedTask
+{
+	Vec		input;
+
+	int		N, M;					// num states, number of inputs
+	VVector	s;						// states
+	VVector	l;						// likelihoods
+	VDouble	L;						// total likelihood
+		
+	Matrix	t;						// deviations: t=(x-mu)/sig
+	Matrix	post;					// posterior over states
+	Matrix	w, mu, sig;			// weights, means, stddevs
+	Matrix	dw, dmu, dsig;		// accumulated stats for learning
+
+	VDouble	nuw, numu, nusig;	// learning rates
+
+
+	// ------------ private bits ------------------------
+
+	private final static double Q=0.5*Math.log(2*Math.PI);
+	private double[] tmp;
+	private double [] _s;
+	private double [] _l;
+	private double [][] _t;
+	private double [][] _w;
+	private double [][] _mu;
+	private double [][] _sig;
+	private double [][] _post;
+	// ---------------------------------------------------
+
+	public MOGVector(Vec in, int n)
+	{
+		super("mog");
+		input=in;
+
+		Shell.push(node);
+		N=n; M=in.size();
+
+		s = new VVector("state",M);
+		l = new VVector("-log p(s)",M);
+		L = new VDouble("likelihood");
+
+		t		= new Matrix("t",N,M);
+		post	= new Matrix("p(s|x)",N,M);
+		mu		= new Matrix("means",N,M);
+		sig	= new Matrix("sigmas",N,M);
+		w		= new Matrix("weight",N,M);
+		dmu	= new Matrix("dmu",N,M);
+		dsig	= new Matrix("dsig",N,M);
+		dw		= new Matrix("dw",N,M);
+
+		tmp=new double[M];
+
+		Shell.pop();
+
+		// initialise parameters
+
+		// weights...
+		w.set(new Constant(1.0/N));
+		w.changed();
+
+		// means...
+		for (int i=0; i<N; i++) {
+			mu.setMatrix(i,i,0,M-1,new Jama.Matrix(1,M,(double)i));
+		}
+		mu.changed();
+
+		// sigmas
+		sig.set(new Constant(1));
+		sig.changed();
+
+
+		// ----- initialise private bits ----------
+		_s=s.array();
+		_l=l.array();
+		_t=t.getArray();
+		_w=w.getArray();
+		_mu=mu.getArray();
+		_sig=sig.getArray();
+		_post=post.getArray();
+	}
+
+	// normalise column sum of w, return array of sums
+	private void normalise(double [][] w, double[] sum) {
+		// make sure weights are normalised properly
+		Mathx.zero(sum);
+		for (int i=0; i<N; i++) Mathx.add(sum,w[i]);
+		for (int i=0; i<N; i++) Mathx.div(w[i],sum);
+	}
+
+	public VFunction getPDF() { 
+		PDF		 fn=new PDF();
+		VFunction vfn=new VFunction("pdf",fn);
+		fn.setupObservers(vfn);
+		return vfn; 
+	}
+
+	public void run() 
+	{
+		// matlab equivalent code:
+		// t = (x - mu)./sig
+		// post = (w./sig).*exp(-0.5*t.^2)
+		// [dummy, s] = max(post,2)
+		// Z = sum(post,2)
+		// post = post./(Z*ones(1,N))
+		// Z=Z*Q;
+
+		{
+			// Vec.Iterator it=input.iterator();
+
+			double [] x=input.array();
+
+			for (int i=0; i<N; i++) {					
+				for (int j=0; j<M; j++) {
+					_t[i][j] = (x[j]-_mu[i][j])/_sig[i][j];
+					_post[i][j] = (_w[i][j]/_sig[i][j])*Math.exp(-0.5*_t[i][j]*_t[i][j]);
+				}
+			}
+
+			// this computes partition function 
+			// and normalises posterior in one go
+			normalise(_post,_l);
+
+			// get MAP state
+			for (int j=0; j<M; j++) {
+				int		state=0;	
+				double	pmax=_post[0][j];
+				
+				for (int i=1; i<N; i++) {
+					if (_post[i][j]>pmax) { state=i; pmax=_post[i][j]; }
+				}
+				_s[j]=state;
+			}
+
+			for (int j=0; j<M; j++) _l[j]=-Math.log(_l[j]);
+
+			L.value = Mathx.sum(_l);
+		}
+
+		L.changed();
+		t.changed();
+		s.changed();
+		l.changed();
+		post.changed();
+	}
+
+	public Task learnTask() {
+		return new AnonymousTask() {
+			// buffer changes to parameters
+			double [][] _dw=dw.getArray();
+			double [][] _dmu=dmu.getArray();
+			double [][] _dsig=dsig.getArray();
+
+			public void starting() { 
+				dw.zero();	
+				dmu.zero();	
+				dsig.zero();
+				normalise(w.getArray(),tmp);
+			}
+
+			public void run() {
+				for (int i=0; i<N; i++) {
+					for (int j=0; j<M; j++) {
+						double pp=_post[i][j], tt=_t[i][j];
+						_dw[i][j]  += pp;
+						_dmu[i][j] += pp*tt;
+						_dsig[i][j]+= pp*(tt*tt-1);
+					}
+				}
+			}
+		};
+	}
+
+	public Task flushTask() { 
+		Shell.push(node);
+
+		try {
+			return new AnonymousTask() {
+
+				VDouble nuw=new VDouble("weights.learn.rate",.001);
+				VDouble numu=new VDouble("means.learn.rate",.001);
+				VDouble nusig=new VDouble("sigmas.learn.rate",.001);
+
+				double [][] _w=w.getArray();
+				double [][] _dw=dw.getArray();
+
+				public void run()
+				{
+					// mu += numu*dmu.*sig
+					// sig += nusig*dsig.*sig;
+					// lambda = sum(w.*dw,2)./sum(w.*w,2)
+					//	dw -= (lambda*ones(1,N)) .* w
+					// w *= exp(nu*dw)
+
+					dmu.arrayTimesEquals(sig);
+					dmu.timesEquals(nusig.value);
+					mu.plusEquals(dmu);
+					mu.changed();
+					dmu.zero();
+
+					dsig.arrayTimesEquals(sig);
+					dsig.timesEquals(nusig.value);
+					sig.plusEquals(dsig);
+					sig.changed();
+					dsig.zero();
+
+					{
+						// the effect of this is to project
+						// dw away from w. the resulting vector
+						// is then added to the log of w
+
+						double nu=nuw.value;
+
+						for (int j=0; j<M; j++) {
+							double w2=0, lambda=0;
+							for (int i=0; i<N; i++) {
+								lambda += _w[i][j]*_dw[i][j];
+								w2     += _w[i][j]*_w[i][j];
+							}
+							tmp[j]=lambda/w2;
+						}
+
+						for (int i=0; i<N; i++) {
+							for (int j=0; j<M; j++) {
+								_w[i][j]  *= Math.exp(
+									nu*(_dw[i][j] - tmp[j]*_w[i][j])
+								);  // update w
+							}
+						}
+					}
+
+					normalise(w.getArray(),tmp);
+					w.changed();
+					dw.zero();	
+				}
+			};
+		} finally { Shell.pop(); }
+	}
+
+	class PDF extends Function implements Observer {
+		VInteger		index;
+		Viewable		vbl=null;
+
+		public PDF() { 
+			index=new VInteger("index",M/2); 
+			index.setRange(0,M-1); 
+			index.addObserver(this);
+		}
+
+		public void dispose() { index.dispose(); }
+
+		public double apply(double x) {
+			double Z=0;
+			for (int i=0; i<N; i++) {
+				int	 j=index.value;
+				double t=(x-_mu[i][j])/_sig[i][j];
+				Z+=(_w[i][j]/_sig[i][j])*Math.exp(-0.5*t*t);
+			}
+			return Z;
+		}
+		public String format(String x) {
+			return "mogpdf("+x+")";
+		}
+
+		public void setupObservers(Viewable v) {
+			w.addObserver(this);
+			mu.addObserver(this);
+			sig.addObserver(this);
+			vbl=v;
+		}
+
+		public void update(Observable o, Object arg) {
+			if (arg!=Viewable.DISPOSING) vbl.changed();
+		}
+	}
+}