view src/samer/models/Mixture.java @ 0:bf79fb79ee13

Initial Mercurial check in.
author samer
date Tue, 17 Jan 2012 17:50:20 +0000
parents
children
line wrap: on
line source
/*
 *	Copyright (c) 2002, Samer Abdallah, King's College London.
 *	All rights reserved.
 *
 *	This software is provided AS IS and WITHOUT ANY WARRANTY;
 *	without even the implied warranty of MERCHANTABILITY or
 *	FITNESS FOR A PARTICULAR PURPOSE.
 */

package samer.models;

import samer.core.*;
import samer.core.types.*;
import samer.maths.*;
import samer.maths.opt.*;
import samer.tools.*;


/**
 * Mixture of <code>m</code> component models over a shared input vector of
 * size <code>n</code>.
 *
 * <p>{@link #infer()} reads each component's energy, combines it with the
 * prior weights to form the posterior over components, and records the
 * normalising constant. {@link #compute()} then forms the gradient as the
 * posterior-weighted sum of the component gradients.
 *
 * <p>The components themselves are NOT run by this class: the caller must
 * have brought every component's energy/gradient up to date before calling
 * {@link #infer()} / {@link #compute()}.
 */
public class Mixture extends NamedTask implements Model
{
	private Model		M[];	// component models (slots filled via setModel)
	private int			n, m;	// size of vector, num models
	private Vec			x;		// input
	private VVector	w;		// prior weights
	private VVector	s;		// posterior
	private int			k;		// MAP estimate (index of largest posterior)
	private VDouble	Z;		// partition function relative to exp(-Emin), ie p(x)*exp(Emin)
	private double		Emin;	// smallest component energy seen by the last infer()
	private double[]  _x,_s,_w,_g; // raw arrays behind x, s, w, and the gradient buffer

	/**
	 * Creates a mixture sized to the given input and attaches that input.
	 *
	 * @param input	input vector; its size becomes the vector size
	 * @param m		number of component models
	 */
	public Mixture( Vec input, int m) { this(input.size(), m); setInput(input); }

	/**
	 * Creates a mixture with no input attached and uniform prior weights.
	 *
	 * @param N	size of the input vector
	 * @param L	number of component models
	 */
	public Mixture( int N, int L)
	{
		super("mixture");

		n = N;
		m = L;
		x = null;

		// The viewables are created while this task's node is pushed on the
		// Shell; the finally guarantees the push is undone even if one of
		// the constructors throws.
		Shell.push(node);
		try {
			w = new VVector("prior",m);
			s = new VVector("posterior",m);
			Z = new VDouble("Z");
			M = new Model[m];
		} finally {
			Shell.pop();
		}

		_s=s.array();
		_w=w.array();
		_g=new double[n];
		Mathx.set(_w,new Constant(1.0/L)); // uniform prior
	}

	/** @return the prior weights over components */
	public VVector prior() 		{ return w; }

	/** @return the posterior over components computed by the last {@link #infer()} */
	public VVector posterior() { return s; }

	/** Installs component model <code>m</code> in slot <code>i</code> (the parameter shadows the field <code>m</code>). */
	public void setModel(int i, Model m) { M[i]=m;	}

	/** Attaches the input vector and caches its backing array. */
	public void setInput(Vec in) { x=in; _x=x.array(); }

	/** @return the size of the input vector */
	public int getSize() { return n; }

	/**
	 * Disposes the viewables owned by this mixture and every component model
	 * that has been installed. Empty model slots are skipped.
	 */
	public void dispose()
	{
		s.dispose();
		w.dispose();
		Z.dispose();
		for (int i=0; i<m; i++) {
			if (M[i]!=null) M[i].dispose(); // slot may never have been set
		}
		super.dispose();
	}

	/**
	 * Computes the posterior over components from their current energies.
	 *
	 * <p>Energies are shifted by their minimum before exponentiation so the
	 * largest term is exactly <code>w[i]*exp(0)</code>; the shift is kept in
	 * <code>Emin</code> so that {@link #getEnergy()} can undo it.
	 */
	public void infer() {
		// get models to compute energies.
		// for (int i=0; i<m; i++) { M[i].infer(); M[i].compute(); }

		// compute relative posterior
		for (int i=0; i<m; i++) _s[i] = M[i].getEnergy(); // collect energies
		Emin=Mathx.min(_s);
		for (int i=0; i<m; i++) _s[i] = _w[i]*Math.exp(Emin-_s[i]);

		// compute partition function, normalise posterior
		Z.set(Mathx.sum(_s));		// partition function (relative to Emin)
		k=Mathx.argmax(_s);			// get MAP model
		Mathx.mul(_s,1/Z.value);	// normalise posterior
		s.changed();
	}

	/**
	 * Computes the mixture gradient as the sum of the component gradients,
	 * each weighted by the posterior of its component. Uses the posterior
	 * left behind by the last {@link #infer()}.
	 */
	public void compute()
	{
		Mathx.zero(_g);
		for (int i=0; i<m; i++) {
			double [] phi = M[i].getGradient();
			double post = _s[i];
			for (int j=0; j<n; j++) _g[j] += post*phi[j];
		}
	}

	/**
	 * Returns the energy of the mixture, <code>-log sum_i w[i]*exp(-E[i])</code>.
	 *
	 * <p><code>Z</code> holds <code>sum_i w[i]*exp(Emin-E[i])</code>, so the
	 * shift <code>Emin</code> applied in {@link #infer()} has to be added
	 * back here; otherwise the result would be off by <code>Emin</code>.
	 */
	public double	getEnergy() { return Emin-Math.log(Z.value); }

	/** @return the gradient buffer filled by {@link #compute()} (not a copy) */
	public double [] getGradient() { return _g; }

	/** Not provided for mixtures; always returns <code>null</code>. */
	public Functionx functionx() { return null; }

	/** Task entry point: performs inference only. */
	public void run() { infer(); }

	/** @return a new trainer bound to this mixture, with no sub-trainers installed */
	public Trainer getTrainer() { return new Trainer(); }

	/**
	 * Trainer for the mixture: forwards posterior-weighted accumulation to
	 * the per-component trainers and adapts the prior weights.
	 */
	public class Trainer implements Model.Trainer
	{
		Model.Trainer	T[];		// per-component trainers, null where none installed
		VDouble		rate;		// learning rate for the prior weights
		VVector		dw;		// accumulated (weighted) posterior
		double		batch, _dw[];	// total accumulated weight; raw array behind dw

		public Trainer() {
			T=new Model.Trainer[m]; // should all be null
			rate=new VDouble("rate",0.001);
			dw=new VVector("dw",m);
			_dw=dw.array();
		}

		/** Installs the trainer for component <code>i</code>. */
		public void setTrainer(int i,Model.Trainer t) { T[i]=t; }

		/** Disposes this trainer's own viewables; sub-trainers are left alone. */
		public void dispose() { rate.dispose(); dw.dispose(); }

		/** Accumulates the current sample with unit weight. */
		public void accumulate() { accumulate(1.0); }

		/**
		 * Accumulates the current sample with the given weight.
		 *
		 * @param w	sample weight (shadows the enclosing prior vector <code>w</code>)
		 */
		public void accumulate(double w) {
			batch+=w;
			for (int i=0;i<m; i++) {
				double share = w*_s[i]; // this component's share of the sample
				if (T[i]!=null) T[i].accumulate(share);

				// accumulate info about priors, weighted like batch and the
				// sub-trainers so that flush() divides like by like
				_dw[i] += share;
			}
		}

		/** Accumulates the current sample with unit weight and flushes immediately. */
		public void oneshot() { accumulate(1.0); flush(); }

		/**
		 * Flushes the sub-trainers, then updates and renormalises the prior
		 * weights from the accumulated posterior and clears the accumulators.
		 *
		 * <p>If nothing has been accumulated the prior weights are left
		 * untouched: with <code>batch == 0</code> the step size would be
		 * infinite and the update would turn every weight into NaN.
		 */
		public void flush() {
			for (int i=0; i<m; i++) if (T[i]!=null) T[i].flush();

			if (batch==0) {
				Mathx.zero(_dw);
				return;
			}

			// lambda*w is the part of dw that lies along w, so the exponent
			// below uses only the remainder of dw
			double lambda=Mathx.dot(_w,_dw)/Mathx.dot(_w,_w);
			double nu=rate.value/batch;

			dw.changed();
			for (int i=0; i<m; i++) {
				_w[i] *= Math.exp(nu*(_dw[i]-lambda*_w[i])); // update w
			}
			Mathx.zero(_dw); batch=0;

			// normalise
			Mathx.mul(_w,1/Mathx.sum(_w));
			w.changed();
		}

		/** Resets the sub-trainers and discards everything accumulated so far. */
		public void reset() {
			for (int i=0; i<m; i++) if (T[i]!=null) T[i].reset();
			Mathx.zero(_dw);
			batch=0;
		}
	}
}