view src/samer/models/MOGVector.java @ 0:bf79fb79ee13

Initial Mercurial check in.
author samer
date Tue, 17 Jan 2012 17:50:20 +0000
parents
children
line wrap: on
line source
package samer.models;

import	samer.core.*;
import	samer.core.types.*;
import	samer.maths.*;
import	samer.tools.*;
import	java.util.*;

/**
 * A bank of independent one-dimensional mixture-of-Gaussians models,
 * one mixture per element of the input vector.
 *
 * All parameter and statistic matrices are N-by-M, where N is the number
 * of mixture components (states) and M is the size of the input vector;
 * entry [i][j] refers to component i of the mixture attached to input
 * element j.
 *
 * Running the task ({@link #run()}) evaluates the model on the current
 * input. {@link #learnTask()} returns a task that accumulates learning
 * statistics from the most recent evaluation, and {@link #flushTask()}
 * returns a task that applies the accumulated statistics to the
 * parameters and clears them.
 */
public class MOGVector extends NamedTask
{
	Vec		input;

	int		N, M;					// num states, number of inputs
	VVector	s;						// states (index of most probable component, per input)
	VVector	l;						// likelihoods (-log of the per-input mixture sum)
	VDouble	L;						// total likelihood (sum of l)
		
	Matrix	t;						// deviations: t=(x-mu)/sig
	Matrix	post;					// posterior over states
	Matrix	w, mu, sig;			// weights, means, stddevs
	Matrix	dw, dmu, dsig;		// accumulated stats for learning

	// NOTE(review): these three fields are never assigned in this class;
	// the learning-rate objects actually used are created as fields of
	// the task returned by flushTask(), which shadow these names.
	VDouble	nuw, numu, nusig;	// learning rates


	// ------------ private bits ------------------------

	// NOTE(review): Q is not referenced anywhere in this class; the
	// likelihood computed in run() therefore omits this constant.
	private final static double Q=0.5*Math.log(2*Math.PI);
	private double[] tmp;			// scratch buffer of length M
	// cached backing arrays of the corresponding viewable objects
	private double [] _s;
	private double [] _l;
	private double [][] _t;
	private double [][] _w;
	private double [][] _mu;
	private double [][] _sig;
	private double [][] _post;
	// ---------------------------------------------------

	/**
	 * Builds a model with n components for each element of the input.
	 *
	 * Weights are initialised to 1/n, the means of component i to the
	 * value i, and all sigmas to 1.
	 *
	 * @param in	vector supplying the observations; its size fixes M
	 * @param n		number of mixture components per input element
	 */
	public MOGVector(Vec in, int n)
	{
		super("mog");
		input=in;

		// create the viewables while this task's node is the current
		// Shell context
		Shell.push(node);
		N=n; M=in.size();

		s = new VVector("state",M);
		l = new VVector("-log p(s)",M);
		L = new VDouble("likelihood");

		t		= new Matrix("t",N,M);
		post	= new Matrix("p(s|x)",N,M);
		mu		= new Matrix("means",N,M);
		sig	= new Matrix("sigmas",N,M);
		w		= new Matrix("weight",N,M);
		dmu	= new Matrix("dmu",N,M);
		dsig	= new Matrix("dsig",N,M);
		dw		= new Matrix("dw",N,M);

		tmp=new double[M];

		Shell.pop();

		// initialise parameters

		// weights...
		w.set(new Constant(1.0/N));
		w.changed();

		// means: row i is filled with the constant value i
		for (int i=0; i<N; i++) {
			mu.setMatrix(i,i,0,M-1,new Jama.Matrix(1,M,(double)i));
		}
		mu.changed();

		// sigmas
		sig.set(new Constant(1));
		sig.changed();


		// ----- initialise private bits ----------
		_s=s.array();
		_l=l.array();
		_t=t.getArray();
		_w=w.getArray();
		_mu=mu.getArray();
		_sig=sig.getArray();
		_post=post.getArray();
	}

	/**
	 * Normalises each column of w to sum to one.
	 *
	 * @param w		N rows of per-input values, normalised in place
	 * @param sum	receives the column sums computed before normalisation
	 */
	private void normalise(double [][] w, double[] sum) {
		// make sure weights are normalised properly
		Mathx.zero(sum);
		for (int i=0; i<N; i++) Mathx.add(sum,w[i]);
		for (int i=0; i<N; i++) Mathx.div(w[i],sum);
	}

	/**
	 * Creates a viewable function that evaluates the mixture attached
	 * to one selected input element; the function is flagged as changed
	 * whenever the weights, means, sigmas or selected index change.
	 */
	public VFunction getPDF() { 
		PDF		 fn=new PDF();
		VFunction vfn=new VFunction("pdf",fn);
		fn.setupObservers(vfn);
		return vfn; 
	}

	/**
	 * Evaluates the model on the current contents of the input vector:
	 * fills in the deviations t, the posterior over components, the
	 * most probable component s and the negative log likelihood l for
	 * every input element, and their total L.
	 */
	public void run() 
	{
		// matlab equivalent code:
		// t = (x - mu)./sig
		// post = (w./sig).*exp(-0.5*t.^2)
		// [dummy, s] = max(post)
		// Z = sum(post)
		// post = post./(ones(N,1)*Z)
		// l = -log(Z)

		{
			double [] x=input.array();

			// deviations and unnormalised posterior
			for (int i=0; i<N; i++) {					
				for (int j=0; j<M; j++) {
					_t[i][j] = (x[j]-_mu[i][j])/_sig[i][j];
					_post[i][j] = (_w[i][j]/_sig[i][j])*Math.exp(-0.5*_t[i][j]*_t[i][j]);
				}
			}

			// this computes partition function (left in _l)
			// and normalises posterior in one go
			normalise(_post,_l);

			// get MAP state: first component with the largest posterior
			for (int j=0; j<M; j++) {
				int		state=0;	
				double	pmax=_post[0][j];
				
				for (int i=1; i<N; i++) {
					if (_post[i][j]>pmax) { state=i; pmax=_post[i][j]; }
				}
				_s[j]=state;
			}

			// convert partition function to negative log likelihood
			for (int j=0; j<M; j++) _l[j]=-Math.log(_l[j]);

			L.value = Mathx.sum(_l);
		}

		L.changed();
		t.changed();
		s.changed();
		l.changed();
		post.changed();
	}

	/**
	 * Returns a task that accumulates learning statistics.
	 *
	 * On starting it clears dw, dmu and dsig and renormalises the
	 * weights; each run adds the contribution of the current post and t
	 * matrices (as left by {@link #run()}) to the accumulators.
	 */
	public Task learnTask() {
		return new AnonymousTask() {
			// buffer changes to parameters
			double [][] _dw=dw.getArray();
			double [][] _dmu=dmu.getArray();
			double [][] _dsig=dsig.getArray();

			public void starting() { 
				dw.zero();	
				dmu.zero();	
				dsig.zero();
				normalise(w.getArray(),tmp);
			}

			public void run() {
				for (int i=0; i<N; i++) {
					for (int j=0; j<M; j++) {
						double pp=_post[i][j], tt=_t[i][j];
						_dw[i][j]  += pp;
						_dmu[i][j] += pp*tt;
						_dsig[i][j]+= pp*(tt*tt-1);
					}
				}
			}
		};
	}

	/**
	 * Returns a task that applies the accumulated statistics to the
	 * parameters and then clears the accumulators.
	 *
	 * The task owns three learning-rate controls, one each for the
	 * weights, the means and the sigmas, created under this task's node.
	 */
	public Task flushTask() { 
		Shell.push(node);

		try {
			return new AnonymousTask() {

				VDouble nuw=new VDouble("weights.learn.rate",.001);
				VDouble numu=new VDouble("means.learn.rate",.001);
				VDouble nusig=new VDouble("sigmas.learn.rate",.001);

				double [][] _w=w.getArray();
				double [][] _dw=dw.getArray();

				public void run()
				{
					// mu += numu*dmu.*sig
					// sig += nusig*dsig.*sig;
					// lambda = sum(w.*dw)./sum(w.*w)
					//	dw -= (ones(N,1)*lambda) .* w
					// w *= exp(nuw*dw)

					// means are scaled by their own learning rate
					// (previously the sigma rate was used here, leaving
					// the means.learn.rate control without any effect)
					dmu.arrayTimesEquals(sig);
					dmu.timesEquals(numu.value);
					mu.plusEquals(dmu);
					mu.changed();
					dmu.zero();

					dsig.arrayTimesEquals(sig);
					dsig.timesEquals(nusig.value);
					sig.plusEquals(dsig);
					sig.changed();
					dsig.zero();

					{
						// the effect of this is to project
						// dw away from w. the resulting vector
						// is then added to the log of w

						double nu=nuw.value;

						// per-input projection coefficient, kept in tmp
						for (int j=0; j<M; j++) {
							double w2=0, lambda=0;
							for (int i=0; i<N; i++) {
								lambda += _w[i][j]*_dw[i][j];
								w2     += _w[i][j]*_w[i][j];
							}
							tmp[j]=lambda/w2;
						}

						for (int i=0; i<N; i++) {
							for (int j=0; j<M; j++) {
								_w[i][j]  *= Math.exp(
									nu*(_dw[i][j] - tmp[j]*_w[i][j])
								);  // update w
							}
						}
					}

					// tmp is reused here to receive the column sums
					normalise(w.getArray(),tmp);
					w.changed();
					dw.zero();	
				}
			};
		} finally { Shell.pop(); }
	}

	/**
	 * Function of one variable giving the weighted sum of Gaussian bumps
	 * for the input element selected by index.
	 *
	 * NOTE(review): like run(), this omits the 1/sqrt(2*pi) factor, so
	 * the result is a density only up to that constant.
	 */
	class PDF extends Function implements Observer {
		VInteger		index;			// which input element to plot
		Viewable		vbl=null;		// viewable to flag when something changes

		public PDF() { 
			index=new VInteger("index",M/2); 
			index.setRange(0,M-1); 
			index.addObserver(this);
		}

		public void dispose() { index.dispose(); }

		public double apply(double x) {
			// the selected column does not change during the sum
			int	 j=index.value;
			double Z=0;
			for (int i=0; i<N; i++) {
				double t=(x-_mu[i][j])/_sig[i][j];
				Z+=(_w[i][j]/_sig[i][j])*Math.exp(-0.5*t*t);
			}
			return Z;
		}
		public String format(String x) {
			return "mogpdf("+x+")";
		}

		/**
		 * Registers this function as an observer of the model parameters
		 * and remembers the viewable that should be flagged on changes.
		 */
		public void setupObservers(Viewable v) {
			w.addObserver(this);
			mu.addObserver(this);
			sig.addObserver(this);
			vbl=v;
		}

		public void update(Observable o, Object arg) {
			// index is observed from the constructor onwards, so a
			// notification can arrive before setupObservers() has
			// supplied the viewable
			if (arg!=Viewable.DISPOSING && vbl!=null) vbl.changed();
		}
	}
}