view src/samer/models/MOGModel.java @ 5:b67a33c44de7

Remove some crap, etc
author samer
date Fri, 05 Apr 2019 21:34:25 +0100
parents bf79fb79ee13
children
line wrap: on
line source
package samer.models;

import	samer.core.*;
import	samer.core.types.*;
import	samer.maths.*;
import	samer.tools.*;

/**
 * Online mixture-of-Gaussians (MOG) model of a stream of scalar values.
 *
 * <p>Each call to {@link #run()} pulls one value from the input generator and
 * hands it to {@link #data(double)}. That method computes:
 * <ul>
 *   <li>the posterior over the N components ({@code post});</li>
 *   <li>the most probable component ({@code state});</li>
 *   <li>the mixture density at that value ({@code Z}).</li>
 * </ul>
 *
 * <p>When learning is enabled, per-sample update terms are accumulated in
 * {@code dw}, {@code dmu} and {@code dsig}. They are applied to the
 * parameters only when {@link #flush()} is called, for example through the
 * task returned by {@link #flushTask()}.
 */
public class MOGModel extends NamedTask
{
	Generator input;

	int		N, state;			// num states, MAP state of the latest datum
	VDouble	Z;					// mixture density at the latest datum

	// NOTE(review): vt is never assigned in this class and so stays null; it is
	// kept only because it is package-visible and may be referenced elsewhere.
	VVector	vt, vpost;			// deviations, posterior
	VVector	vw, vmu, vsig;		// weights, means, and stddevs

	double[] t, post;				// deviations, posterior
	double[]	w, mu, sig;			// weights, means, and stddevs
	double[] dw, dmu, dsig;		// deltas
	VDouble	nuw, numu, nusig;	// learning rates
	boolean	lrnW, lrnMu, lrnSig;

	/** Normalising constant of the standard Gaussian density, 1/sqrt(2*pi). */
	final static double Q=1/Math.sqrt(2*Math.PI);

	/**
	 * Creates the model and initialises its parameters.
	 *
	 * <p>Initial values:
	 * <ul>
	 *   <li>weights are uniform (1/n);</li>
	 *   <li>sigmas are 1;</li>
	 *   <li>means are drawn from a {@code NormalisedGaussian} generator.</li>
	 * </ul>
	 *
	 * <p>The learning flags default to {@code true} and the learning rates
	 * default to 0.001.
	 *
	 * @param in source of scalar data; one value is consumed per {@link #run()}
	 * @param n  number of mixture components
	 */
	public MOGModel(Generator in, int n)
	{
		super("mog");
		input=in;

		// The observable parameters below are created inside this task's node
		// context, between Shell.push and Shell.pop.
		Shell.push(node);
		N=n;

		Z = new VDouble("likelihood");

		t		= new double[N];
		post	= new double[N];
		w		= new double[N];
		mu		= new double[N];
		sig	= new double[N];
		dw		= new double[N];
		dmu	= new double[N];
		dsig	= new double[N];

		vpost=new VVector("p",post);
		vw=new VVector("weights",w);
		vmu=new VVector("means",mu);
		vsig=new VVector("sigmas",sig);

		lrnW=Shell.getBoolean("weights.learn",true);
		lrnMu=Shell.getBoolean("means.learn",true);
		lrnSig=Shell.getBoolean("sigmas.learn",true);

		nuw=new VDouble("weights.learn.rate",.001);
		numu=new VDouble("means.learn.rate",.001);
		nusig=new VDouble("sigmas.learn.rate",.001);

		Shell.pop();

		// initialise parameters
		Generator rnd=new samer.maths.random.NormalisedGaussian();

		// weights: uniform
		Mathx.zero(w);
		Mathx.add(w,1.0/N);
		vw.changed();

		// means: random draws
		Mathx.set(mu,rnd);
		vmu.changed();

		// sigmas: all 1
		Mathx.zero(sig);
		Mathx.add(sig,1);
		vsig.changed();

	}

	/** Consumes one value from the input generator and processes it. */
	public void run() { data(input.next()); }

	/**
	 * Clears any accumulated parameter deltas and renormalises the weights so
	 * that they sum to 1.
	 */
	public void starting() { 
		Mathx.zero(dw);	
		Mathx.zero(dmu);	
		Mathx.zero(dsig);	

		// make sure weights are normalised properly
		Mathx.mul(w,1/Mathx.sum(w));
		vw.changed();	// notify observers, as flush() does after renormalising
	}

	/** Returns a task whose {@code run()} calls {@link #flush()}. */
	public Task flushTask() { 
		return new AnonymousTask() {
			public void run() { flush(); }
		};
	}

	/**
	 * Returns a function that evaluates the current mixture density at x:
	 * Q * sum_i (w_i/sig_i) * exp(-0.5*((x-mu_i)/sig_i)^2).
	 *
	 * <p>The function reads the live parameter arrays, so it reflects later
	 * parameter updates. It does not modify any model state.
	 */
	public Function getPDF() 
	{
		return new Function() {
			public double apply(double x) {
				double density=0;
				for (int i=0; i<N; i++) {
					// A local deviation (not the shared t[] buffer), so that
					// evaluating the pdf never clobbers state used by data().
					double ti=(x-mu[i])/sig[i];
					density+=(w[i]/sig[i])*Math.exp(-0.5*ti*ti);
				}
				return density*Q;
			}
			public String format(String x) {
				return "mogpdf("+x+")";
			}
		};
	}

	/**
	 * Processes one datum x and notifies observers of {@code Z} and the
	 * posterior.
	 *
	 * <p>This method:
	 * <ul>
	 *   <li>computes the standardised deviations t_i = (x-mu_i)/sig_i;</li>
	 *   <li>computes the normalised posterior {@code post};</li>
	 *   <li>sets the MAP component {@code state};</li>
	 *   <li>sets the mixture density at x ({@code Z});</li>
	 *   <li>when learning is enabled, accumulates update terms that are
	 *       applied later by {@link #flush()}.</li>
	 * </ul>
	 *
	 * @param x the observed value
	 */
	public void data(double x) 
	{
		// Pass 1: deviations, plus each component's Gaussian exponent
		// (temporarily stored in post[]) and the largest such exponent.
		double emax=Double.NEGATIVE_INFINITY;
		for (int i=0; i<N; i++) {
			t[i]=(x-mu[i])/sig[i];
			post[i]=-0.5*t[i]*t[i];
			if (post[i]>emax) emax=post[i];
		}

		// Pass 2: unnormalised posterior and MAP state. Every term is rescaled
		// by exp(-emax) so that at least one exponential equals 1. This keeps
		// the sum from underflowing to 0 when x is far from every component,
		// which used to turn post[] (and then the learned parameters) into
		// NaN. The common factor cancels on normalisation and does not change
		// the argmax.
		state=0;	double sum=0;
		for (int i=0; i<N; i++) {
			post[i] = (w[i]/sig[i])*Math.exp(post[i]-emax);
			sum+=post[i];

			if (post[i]>post[state]) state=i;
		}
		Mathx.mul(post,1/sum);
		double like=Q*sum*Math.exp(emax);	// undo the rescaling: the true density

		if (lrnW||lrnMu||lrnSig) {
			// Buffer changes to parameters; they are applied in flush().
			for (int i=0; i<N; i++) {
				double p=post[i], tt=t[i];
				if (lrnW)		dw[i]  += p;
				if (lrnMu)	dmu[i] += p*tt;
				if (lrnSig)	dsig[i]+= p*(tt*tt-1);
			}
		}

		Z.value=like; 
		Z.changed();
		vpost.changed();
	}

	/**
	 * Applies the accumulated deltas to every parameter group whose learning
	 * flag is set. It then clears those deltas and notifies observers.
	 *
	 * <ul>
	 *   <li>Means: mu += numu * sig * dmu.</li>
	 *   <li>Sigmas: sig += nusig * sig * dsig.</li>
	 *   <li>Weights: dw is projected to remove its component along w, then
	 *       w_i *= exp(nuw * dw_i), and finally w is renormalised to sum
	 *       to 1.</li>
	 * </ul>
	 */
	public void flush()
	{
		if (lrnMu) {
			Mathx.mul(dmu,sig);
			flush(mu,dmu,numu.value);
			vmu.changed();
		}
		if (lrnSig)	{ 
			// NOTE(review): this additive update can make sig non-positive if
			// nusig times the accumulated (negative) delta exceeds 1; nothing
			// here guards against that.
			Mathx.mul(dsig,sig);
			flush(sig,dsig,nusig.value);
			vsig.changed();
		}
		if (lrnW)	{ 
			double lambda=Mathx.dot(w,dw)/Mathx.dot(w,w);
			double nu=nuw.value;

			for (int i=0; i<N; i++) {
				dw[i] -= lambda*w[i];			// project away from w
				w[i]  *= Math.exp(nu*dw[i]);  // multiplicative update keeps w positive
			}
			Mathx.zero(dw);	

			// normalise
			Mathx.mul(w,1/Mathx.sum(w));

			vw.changed();
		}

	}

	/**
	 * Additive update of one parameter vector: theta += nu * dtheta. The
	 * delta vector is then cleared.
	 */
	private void flush(double [] theta, double [] dtheta, double nu) 
	{
		Mathx.mul(dtheta,nu);
		Mathx.add(theta,dtheta); 
		Mathx.zero(dtheta);	
	}
}