view src/samer/models/VarianceICA.java @ 0:bf79fb79ee13

Initial Mercurial check in.
author samer
date Tue, 17 Jan 2012 17:50:20 +0000
parents
children
line wrap: on
line source
/*
 *	Copyright (c) 2000, Samer Abdallah, King's College London.
 *	All rights reserved.
 *
 *	This software is provided AS IS and WITHOUT ANY WARRANTY;
 *	without even the implied warranty of MERCHANTABILITY or
 *	FITNESS FOR A PARTICULAR PURPOSE.
 */

package samer.models;
import  samer.core.*;
import  samer.core.types.*;
import  samer.tools.*;
import  samer.maths.*;
import  samer.maths.*;
import  samer.maths.opt.*;

/**
 * ICA-style model in which the basis matrix A (n x m) maps the sources s
 * onto a reconstruction z = A s, and each data component x_i is compared
 * with z_i through the energy term  t_i - log t_i,  with  t_i = x_i / z_i.
 * <p>
 * Up to terms that do not depend on s, this data term equals the negative
 * log-likelihood of x_i under an exponential distribution with mean z_i
 * (NOTE(review): interpretation derived from the maths, confirm intent).
 * log(t_i) is only finite when x_i / z_i > 0.
 * The total energy adds the energy of the source model Ms.
 */
public class VarianceICA extends NamedTask implements Model
{
	Vec		x;				// data (input)
	int		n, m;			// sizes (data, sources)
	Matrix	A;				// basis matrix, n x m
	Model		Ms;			// source model
	VVector	s, z, e;		// sources, reconstruction, error
	VDouble	E;				// energy

	Task			inference=new NullTask();


	// ----- variables used in computations -----------------

	private	double []	e_;
	private	double []	x_;
	private	double []	z_;
	private	double []	s_;
	private	VectorFunctionOfVector tA, tAt;	// tA: z = A s;  tAt: A' times an n-vector


	/**
	 * Creates the model's variables and sets the basis matrix with A.identity().
	 *
	 * @param node    node pushed onto the Shell while the variables are created
	 * @param inputs  dimension n of the data vector x
	 * @param outputs dimension m of the source vector s
	 */
	public VarianceICA(Node node, int inputs, int outputs)
	{
		super(node);
		Shell.push(node);
		try {
			n = inputs;
			m = outputs;

			s = new VVector("s",m);
			z = new VVector("z",n);
			e = new VVector("e",n);
			E = new VDouble("E");
			A = new Matrix("A",n,m);
			A.identity();

			e_ = e.array();
			z_ = z.array();
			s_ = s.array();
			tA = new MatrixTimesVector(A);
			tAt= new MatrixTransposeTimesVector(A);
		} finally {
			// keep the Shell's node stack balanced even if construction throws
			Shell.pop();
		}
	}

	/** @return the source model set by {@link #setSourceModel}, or null if none. */
	public Model getSourceModel() { return Ms; }

	/** Sets the source model and connects the source vector s to its input. */
	public void  setSourceModel(Model model) { Ms=model; Ms.setInput(s); }

	/** Sets the data vector x; its backing array is cached for the computations. */
	public void  setInput(Vec in) { x=in; x_ = x.array(); }

	/** @return the n x m basis matrix A. */
	public Matrix basisMatrix() { return A; }

	/** @return the source vector s. */
	public VVector output() { return s; }

	/** @return the error vector e = (x - z) / z^2 computed by {@link #infer()}. */
	public VVector error() { return e; }

	/** @return the data dimension n. */
	public int  getSize() { return n; }

	/** Sets the task run at the start of {@link #infer()}. */
	public void setInferenceTask(Task t) { inference=t; }

	/**
	 * Runs the inference task, then computes the reconstruction z = A s and
	 * the error e = (x - z) / z^2. The error equals -dE/dz for the per-component
	 * energy t - log t with t = x/z. Finally calls changed() on e, z and s.
	 *
	 * @throws Error if the inference task throws; the original exception is
	 *         attached as the cause
	 */
	public void infer() {
		try {	inference.run(); }
		catch (Exception ex) { throw new Error("inference failed", ex); }
		tA.apply(s_,z_);		// z = A s
		Mathx.sub(e_,x_,z_);	// e = x - z
		Mathx.div(e_,z_);		// e = (x - z) / z
		Mathx.div(e_,z_);		// e = (x - z) / z^2
		e.changed(); z.changed(); s.changed();
	}

	/**
	 * Sets E to  sum_i (t_i - log t_i) - n + Ms.getEnergy(),  where
	 * t_i = x_i / z_i uses the current contents of z.
	 * Each term t - log t has its minimum value 1 at t = 1, so subtracting n
	 * makes the data part zero for a perfect reconstruction.
	 */
	public void compute() {
		double sum=0;
		for (int i=0; i<n; i++) {
			double t=x_[i]/z_[i];
			sum += t - Math.log(t);
		}
		E.set(sum - n + Ms.getEnergy());
		// TODO: dE/dx is not computed here (see getGradient()).
	}

	/** @return the energy last stored by {@link #compute()}. */
	public double getEnergy() { return E.value; }

	/**
	 * Not implemented: always returns null.
	 * The original author marked this as "this is wrong".
	 */
	public double [] getGradient() { return null; } // this is wrong

	/** Runs {@link #infer()}. */
	public void run() { infer(); }

	/** Disposes the variables and operators created by the constructor. */
	public void dispose()
	{
		s.dispose();
		A.dispose();
		e.dispose();
		z.dispose();
		E.dispose();
		tA.dispose();
		tAt.dispose();

		super.dispose();
	}

	/** Not implemented: always returns null. */
	public Functionx functionx() { return null; }

	/**
	 * Returns a Functionx that evaluates, for the current data x, the energy
	 * E(s) and its gradient dE/ds with respect to the sources.
	 * <p>
	 * With z = A s:  dE/dz_i = (1 - t_i) / z_i,  so  dE/ds = A' dE/dz + dEs/ds,
	 * where Es is the source model's energy.
	 * The returned energy does NOT include the constant -n that
	 * {@link #compute()} subtracts.
	 * Side effect: each evaluation overwrites the contents of z with A s,
	 * without calling z.changed().
	 */
	public Functionx posterior()	{
		return new Functionx() {
			Functionx fMs=Ms.functionx();
			double [] negdEdz=new double[n];	// -dE/dz, one entry per data component
			double [] ge=new double[m];		// A' * negdEdz, one entry per source

			public void dispose() { fMs.dispose(); }
			public void evaluate( Datum P) { P.f=evaluate(P.x,P.g); }
			public double evaluate( double [] s, double [] g)
			{
				double energy=0;

				tA.apply(s,z_);	// z = A s
				for (int i=0; i<n; i++) {
					double t=x_[i]/z_[i];
					negdEdz[i] = (t-1)/z_[i];
					energy += t - Math.log(t);
				}
				tAt.apply(negdEdz,ge);	// ge = -A' dE/dz  (length m)
				energy+=fMs.evaluate(s,g);	// g = dEs/ds
				Mathx.sub(g,ge);		// g = dEs/ds + A' dE/dz
				return energy;
			}
		};
	}

	/**
	 * Returns a MatrixTrainer for A driven by the error vector e and the
	 * sources s. Presumably it accumulates the outer product e s'
	 * (NOTE(review): confirm against MatrixTrainer).
	 */
	public Trainer learnHebbian() 	{
		Shell.push(node);
		try {	return new MatrixTrainer(e_,A,s_); }
		finally { Shell.pop(); }
	}

	/**
	 * Returns a MatrixTrainer for A driven by h = A * (source model gradient)
	 * and s. On flush, it subtracts count * A from the accumulated update
	 * before delegating.
	 */
	public Trainer learnLewickiSejnowski()
	{
		Shell.push(node);
		try {
			final double [] h =new double[n];
			return new MatrixTrainer(h,A,s_) {

				public void accumulate(double w) {
					tA.apply(Ms.getGradient(),h);	// h = A * dEs/ds
					super.accumulate(w);
				}
				public void flush() { // flush with decay
					for (int j=0; j<this.m; j++)
						for (int i=0; i<this.n; i++)
							_T[i][j] -= count*_A[i][j];
					super.flush();
				}
			};
		} finally { Shell.pop(); }
	}

	/**
	 * Like {@link #learnLewickiSejnowski()}, but the decay is applied per
	 * accumulation and only to the columns of A whose source s_j lies
	 * outside [-threshold, threshold]. The threshold is a VDouble named
	 * "threshold" with initial value 0.01.
	 */
	public Trainer learnDecayWhenActive()
	{
		Shell.push(node);
		try {
			final double [] h =new double[n];
			return new MatrixTrainer(h,A,s_) {
				VDouble				th=new VDouble("threshold",0.01);
				public void accumulate(double w) {
					tA.apply(Ms.getGradient(),h);	// h = A * dEs/ds
					super.accumulate(w);

					// decay when active part
					double thresh=th.value;
					for (int j=0; j<this.m; j++)
						if (isActive(s_[j],thresh))
							for (int i=0; i<this.n; i++)
								_T[i][j] -= _A[i][j];
				}
			};
		} finally { Shell.pop(); }
	}

	/** @return true iff |s| > t (strictly). */
	static boolean isActive(double s, double t) { return s>t || s<-t; }
}