diff src/samer/models/NoisyICA.java @ 0:bf79fb79ee13

Initial Mercurial check in.
author samer
date Tue, 17 Jan 2012 17:50:20 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/samer/models/NoisyICA.java	Tue Jan 17 17:50:20 2012 +0000
@@ -0,0 +1,191 @@
+/*
+ *	Copyright (c) 2000, Samer Abdallah, King's College London.
+ *	All rights reserved.
+ *
+ *	This software is provided AS iS and WITHOUT ANY WARRANTY;
+ *	without even the implied warranty of MERCHANTABILITY or
+ *	FITNESS FOR A PARTICULAR PURPOSE.
+ */
+
+package samer.models;
+import  samer.core.*;
+import  samer.core.types.*;
+import  samer.tools.*;
+import  samer.maths.*;
+import  samer.maths.*;
+import  samer.maths.opt.*;
+
+public class NoisyICA extends NamedTask implements Model
+{
+	Vec				x;				// data (input)
+	int				n, m;			// sizes (data, sources)
+	Matrix		A;				// basis matrix
+	Model		Ms, Me;	// source and noise models
+	VVector	s, z, e;		// sources, reconstruction, error
+	VDouble		E;				// energy
+
+	Task			inference=new NullTask();
+
+
+	// ----- variables used in computations -----------------
+
+	private	double []	e_;
+	private	double []	x_;
+	private	double []	z_;
+	private	double []	s_;
+	private	VectorFunctionOfVector tA, tAt;
+
+
+	public NoisyICA(Vec in,int outs) { this(new Node("noisyica"),in.size(),outs); setInput(in); }
+	public NoisyICA(int ins,int outs) { this(new Node("noisyica"),ins,outs); }
+	public NoisyICA(Node node, int inputs, int outputs)
+	{
+		super(node);
+		Shell.push(node);
+
+		n = inputs;
+		m = outputs;
+
+		s = new VVector("s",m);
+		z = new VVector("z",n);
+		e = new VVector("e",n);
+		E = new VDouble("E");
+		A = new Matrix("A",n,m);
+		A.identity();
+
+		e_ = e.array();
+		z_ = z.array();
+		s_ = s.array();
+		tA = new MatrixTimesVector(A);
+		tAt= new MatrixTransposeTimesVector(A);
+
+		Shell.pop();
+	}
+
+	public Model getSourceModel() { return Ms; }
+	public Model getNoiseModel() { return Me; }
+	public void  setSourceModel(Model m) { Ms=m; Ms.setInput(s); }
+	public void  setNoiseModel(Model m) { Me=m; Me.setInput(e); }
+	public void  setInput(Vec in) { x=in; x_ = x.array(); }
+	public Matrix basisMatrix() { return A; }
+	public VVector output() { return s; }
+	public VVector error() { return e; }
+	public VVector reconstruction() { return z; }
+
+	public int  getSize() { return n; }
+
+	public void setInferenceTask(Task t) { inference=t; }
+
+	public void infer() {
+		try {	inference.run(); }
+		catch (Exception ex) { 
+			Shell.trace("error: "+ex);
+			ex.printStackTrace();
+			throw new Error("inference failed: "+ex); }
+		tA.apply(s_,z_); Mathx.sub(e_,x_,z_);
+		e.changed(); z.changed(); s.changed();
+	}
+
+	public void compute() {
+		E.set(Me.getEnergy() + Ms.getEnergy());
+		// what about dE/dx?
+	}
+	public double getEnergy() { return E.value; }
+	public double [] getGradient() { return null; } // this is wrong
+
+	/** get basis vector norms into given array */
+	public void norms(double [] na)
+	{
+		double [][] M=A.getArray();
+
+		Mathx.zero(na);
+		for (int i=0; i<n; i++) {
+			double [] Mi=M[i];
+			for (int j=0; j<m; j++) {
+				na[j] += Mi[j]*Mi[j];
+			}
+		}
+	}
+
+	public void run() { infer(); }
+	public void dispose()
+	{
+		s.dispose();
+		A.dispose();
+		e.dispose();
+		z.dispose();
+		E.dispose();
+		tA.dispose();
+		tAt.dispose();
+
+		super.dispose();
+	}
+
+	public Functionx functionx() { return null; }
+
+	public Functionx posterior()	{
+		// returns Functionx which evaluates E and dE/ds at current x
+		return new Functionx() {
+			Functionx fMs=Ms.functionx();
+			Functionx fMe=Me.functionx();
+			double [] e=new double[n];
+			double [] ge=new double[n];
+			double [] gs=new double[m];
+
+			public void dispose() { fMs.dispose(); fMe.dispose(); }
+			public void evaluate( Datum P) { P.f=evaluate(P.x,P.g); }
+			public double evaluate( double [] s, double [] g)
+			{
+				tA.apply(s,z_); Mathx.sub(e,x_,z_);
+				double Ee=fMe.evaluate(e,ge); tAt.apply(ge,gs); // gs=A'*gamma(e)
+				double Es=fMs.evaluate(s,g);	Mathx.sub(g,gs); // g=gamma(s)-ge
+				return Es+Ee;
+			}
+		};
+	}
+
+	public Trainer learnHebbian() 	{ return new MatrixTrainer(e_,A,s_); }
+
+	public Trainer learnLewickiSejnowski()
+	{
+		final double [] h =new double[n];
+		final double [] f =new double[m];
+		return new MatrixTrainer(h,A,s_) {
+
+			public void accumulate(double w) {
+				// tAt.apply(Me.getGradient(),f); tA.apply(f,h);
+				tA.apply(Ms.getGradient(),h);
+				super.accumulate(w);
+			}
+			public void flush() { // flush with decay
+				for (int j=0; j<this.m; j++)
+					for (int i=0; i<this.n; i++)
+						_T[i][j] -= count*_A[i][j];
+				super.flush();
+			}
+		};
+	}
+
+	public Trainer learnDecayWhenActive()
+	{
+		final double [] h =new double[n];
+		return new MatrixTrainer(h,A,s_) {
+			VDouble				th=new VDouble("threshold",0.01);
+			public void accumulate(double w) {
+				// perhaps would like to get info out of optimiser here
+				tA.apply(Ms.getGradient(),h);
+				super.accumulate(w);
+
+				// decay when active part
+				double thresh=th.value;
+				for (int j=0; j<this.m; j++)
+					if (isActive(s_[j],thresh))
+						for (int i=0; i<this.n; i++)
+							_T[i][j] -= _A[i][j];
+			}
+
+			public VDouble getThreshold() { return th; }
+		};
+	}
+	static boolean isActive(double s, double t) { return s>t || s<-t; }
+}