diff src/samer/models/Mixture.java @ 0:bf79fb79ee13

Initial Mercurial check in.
author samer
date Tue, 17 Jan 2012 17:50:20 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/samer/models/Mixture.java	Tue Jan 17 17:50:20 2012 +0000
@@ -0,0 +1,153 @@
+/*
+ *	Copyright (c) 2002, Samer Abdallah, King's College London.
+ *	All rights reserved.
+ *
+ *	This software is provided AS iS and WITHOUT ANY WARRANTY;
+ *	without even the implied warranty of MERCHANTABILITY or
+ *	FITNESS FOR A PARTICULAR PURPOSE.
+ */
+
+package samer.models;
+
+import samer.core.*;
+import samer.core.types.*;
+import samer.maths.*;
+import samer.maths.opt.*;
+import samer.tools.*;
+
+
+public class Mixture extends NamedTask implements Model
+{
+	private Model 		M[]; 	// models
+	private int			n, m; // size of vector, num models
+	private Vec			x;		// input
+	private VVector	w;		// prior weights
+	private VVector	s;		// posterior
+	private int			k; 	// MAP estimate
+	private VDouble	Z;		// Parition function, ie p(x)
+	private double[]  _x,_s,_w,_g;
+
+	public Mixture( Vec input, int m) { this(input.size(), m); setInput(input); }
+	public Mixture( int N, int L)
+	{
+		super("mixture");
+		Shell.push(node);
+
+		n = N;
+		m = L;
+
+		x = null;
+		w = new VVector("prior",m);
+		s = new VVector("posterior",m);
+		Z = new VDouble("Z");
+		M = new Model[m];
+		Shell.pop();
+	
+		_s=s.array();
+		_w=w.array();
+		_g=new double[n];
+		Mathx.set(_w,new Constant(1.0/L));
+	}
+
+	public VVector prior() 		{ return w; }
+	public VVector posterior() { return s; }
+	public void setModel(int i, Model m) { M[i]=m;	}
+	public void setInput(Vec in) { x=in; _x=x.array(); }
+	public int getSize() { return n; }
+
+	public void dispose()
+	{
+		s.dispose();
+		w.dispose();
+		Z.dispose();
+		for (int i=0; i<m; i++) M[i].dispose();
+		super.dispose();
+	}
+
+	public void infer() {
+		// get models to compute energies.
+		// for (int i=0; i<m; i++) { M[i].infer(); M[i].compute(); }
+
+		// compute relative posterior
+		for (int i=0; i<m; i++) _s[i] = M[i].getEnergy(); // collect energies
+		double Emin=Mathx.min(_s);
+		for (int i=0; i<m; i++) _s[i] = _w[i]*Math.exp(Emin-_s[i]);
+
+		// compute partition function, normalise posterior
+		Z.set(Mathx.sum(_s));	// compute parition function
+		k=Mathx.argmax(_s);	// get MAP model
+		Mathx.mul(_s,1/Z.value);	// normalise posterior
+		s.changed();
+	}
+
+	public void compute()
+	{
+		/* compute gradients weighted by posterior */
+		Mathx.zero(_g);
+		for (int i=0; i<m; i++) {
+			double [] phi = M[i].getGradient();
+			for (int j=0; j<n; j++) _g[j] += _s[i]*phi[j];
+		}
+	}
+
+	public double	getEnergy() { return -Math.log(Z.value); }
+	public double [] getGradient() { return _g; }
+
+	public Functionx functionx() { return null; }
+
+	public void run() { infer(); }
+
+	public Trainer getTrainer() { return new Trainer(); }
+
+	public class Trainer implements Model.Trainer
+	{
+		Model.Trainer	T[];
+		VDouble		rate;
+		VVector		dw;
+		double		batch, _dw[];
+
+		public Trainer() {
+			T=new Model.Trainer[m]; // should all be null
+			rate=new VDouble("rate",0.001);
+			dw=new VVector("dw",m);
+			_dw=dw.array();
+		}
+
+		public void setTrainer(int i,Model.Trainer t) { T[i]=t; }
+		public void dispose() { rate.dispose(); dw.dispose(); }
+
+		public void accumulate() { accumulate(1.0); }
+		public void accumulate(double w) {
+			batch+=w;
+			for (int i=0;i<m; i++) {
+				if (T[i]!=null) T[i].accumulate(w*_s[i]); // sweet
+			}
+
+			// now accumulate info about priors
+			Mathx.add(_dw,_s);
+		}
+
+		public void oneshot() { accumulate(1.0); flush(); }
+		public void flush() {
+			for (int i=0; i<m; i++) if (T[i]!=null) T[i].flush();
+         double lambda=Mathx.dot(_w,_dw)/Mathx.dot(_w,_w);
+			double nu=rate.value/batch;
+
+			dw.changed();
+			for (int i=0; i<m; i++) {
+				_w[i] *= Math.exp(nu*(_dw[i]-lambda*_w[i])); // update w
+			}
+			Mathx.zero(_dw); batch=0;
+
+			// normalise
+			Mathx.mul(_w,1/Mathx.sum(_w));
+			w.changed();
+		}
+		public void reset() { 
+			for (int i=0; i<m; i++) if (T[i]!=null) T[i].reset();
+			Mathx.zero(_dw);
+			batch=0;
+		}
+	}
+}
+