view src/samer/models/ICA.java @ 5:b67a33c44de7

Remove some crap, etc
author samer
date Fri, 05 Apr 2019 21:34:25 +0100
parents bf79fb79ee13
children
line wrap: on
line source
/*
 *	Copyright (c) 2002, Samer Abdallah, King's College London.
 *	All rights reserved.
 *
 *	This software is provided AS IS and WITHOUT ANY WARRANTY;
 *	without even the implied warranty of MERCHANTABILITY or
 *	FITNESS FOR A PARTICULAR PURPOSE.
 */

package samer.models;
import  samer.core.*;
import  samer.core.Agent.*;
import  samer.core.types.*;
import  samer.tools.*;
import  samer.maths.*;
import  samer.maths.opt.*;

/**
 * Linear ICA model. The source vector is computed from the input as
 * {@code s = W x} (see {@link #infer()}), and the energy reported by
 * {@link #getEnergy()} is the energy of the attached source model plus the
 * stored term {@code logA}, which the "logdet" command sets to
 * {@code -W.logdet()}.
 *
 * <p>Trainers for {@code W}: {@link ON2Trainer} (O(N^2) accumulate and flush),
 * {@link ON3Trainer} (cheaper accumulate, O(N^3) flush),
 * {@link ON2DecayWhenActive}, and the unfinished {@link NewtonTrainer}.
 */
public class ICA extends Viewable implements SafeTask, Model, Agent
{
	Model			sourceModel;
	int				n;
	Vec			x;
	Matrix			W, A;	// W: s=Wx; A: basis, set to inverse(W) by the "basis" command
	VVector  		s;				// state
	VDouble		logA;
	double[]		_g; // cached dE/dx

	MatrixTimesVector							infer;	// s=Wx; null until setInput() is called
	MatrixTransposeTimesVector	grad;

	/** Creates a model sized to the given input vector and attaches it as input. */
	public ICA(Vec input) { this(input.size()); setInput(input); }

	/** Creates an N-dimensional model under a new node named "ica"; no input is attached. */
	public ICA(int N) { this(new Node("ica"),N); }

	/**
	 * Creates an N-dimensional model. The named components (W, A, s, log |A|)
	 * are created while {@code node} is pushed on the Shell. W starts as the
	 * identity; A is allocated but only filled in by the "basis" command.
	 * No input is attached, so {@link #setInput(Vec)} must be called before
	 * {@link #infer()}.
	 */
	public ICA(Node node, int N)
	{
		super(node);
		Shell.push(node);

		n = N;

		x = null;
		W = new Matrix("W",n,n);
		A = new Matrix("A",n,n); 	// should we defer creation of this?
		s = new VVector("s",n);
		logA = new VDouble( "log |A|");
		infer = null;
		Shell.pop();

		W.identity(); W.changed();
		grad = new MatrixTransposeTimesVector(W);
		_g = new double[n];

		setAgent(this);
	}

	/** @return the dimension N of the input and source vectors */
	public int getSize() { return n; }
	/** @return the source vector s */
	public VVector output() { return s; }
	/** @return the model of the sources s (may be null if never set) */
	public Model getOutputModel() { return sourceModel; }
	/** Sets the model of the sources s, used for energies and gradients. */
	public void setOutputModel(Model m) {	sourceModel=m; }

	/**
	 * Attaches the input vector x and creates the task that computes s=Wx.
	 * @throws Error if {@code input.size()} differs from N
	 */
	public void setInput(Vec input) {
		if (input.size()!=n) throw new Error("Input vector is the wrong size");
		x = input; infer = new MatrixTimesVector(s,W,x);
	}

	/** @return the weight matrix W (s = W x) */
	public Matrix getWeightMatrix() { return W; }
	/** @return the basis matrix A, as last computed by the "basis" command */
	public Matrix getBasisMatrix() { return A; }

	/** Releases the matrices, vectors and helper tasks owned by this model. */
	public void dispose()
	{
		W.dispose(); A.dispose();
		logA.dispose();
		s.dispose();
		// infer is only created by setInput(), which may never have been called
		if (infer!=null) infer.dispose();
		grad.dispose();

		super.dispose();
	}

	/** Computes s=Wx from the attached input and notifies observers of s. */
	public void infer() {	// compute s=Wx
		infer.run();
		s.changed();
	}

	/**
	 * Computes the cached gradient dE/dx by mapping the source model's current
	 * gradient back through W (presumably {@code _g = W'*phi}).
	 */
	public void compute() {
		grad.apply(sourceModel.getGradient(),_g);
	}

	/** @return the source model's energy plus logA (logA is only updated by "logdet") */
	public double getEnergy() { return sourceModel.getEnergy()+logA.value; }
	/** @return the array filled by {@link #compute()} (not a copy) */
	public double[] getGradient() { return _g; }

	/**
	 * Returns a function of the input x that computes the sources into this
	 * model's own s array, evaluates the source model's function there, and maps
	 * the source gradient back into g through {@code grad}. The returned value
	 * includes logA. Requires an input to have been attached (uses infer).
	 */
	public Functionx functionx() {
		return new Functionx() {
			Functionx fM=sourceModel.functionx();
			double [] s=ICA.this.s.array(); //new double[n];
			double [] gs=new double[n];

			public void dispose() { fM.dispose();  }
			public void evaluate(Datum P) { P.f=evaluate(P.x,P.g); }
			public double evaluate(double [] x, double [] g) {
				infer.apply(x,s); double E=fM.evaluate(s,gs);
				grad.apply(gs,g); return E+logA.value;
			}
		};
	}

	public void starting() {}
	public void stopping() {}
	/** Task step: recomputes s=Wx. */
	public void run() { infer(); }

	/** Registers the "basis" and "logdet" commands. */
	public void getCommands(Registry r) { r.add("basis").add("logdet"); }

	/**
	 * "basis": sets A to the inverse of W and notifies observers of A.
	 * "logdet": sets logA to {@code -W.logdet()}.
	 */
	public void execute(String cmd, Environment env)
	{
		if (cmd.equals("basis")) {
			Shell.print("computing ICA basis...");
			A.assign(W.inverse());
			A.changed();
			Shell.print("...done.");
		} else if (cmd.equals("logdet")) {
			logA.set(-W.logdet());
		}
	}

	public Trainer getTrainer() { return new ON2Trainer(); }
	public Trainer getAltTrainer() { return new ON3Trainer(); }
	public Trainer getDecayWhenActiveTrainer() { return new ON2DecayWhenActive(); }

	/**
	 * This trainer uses an O(N^2) run step and an O(N^2) flush.
	 * Each sample adds {@code w * phi * (W'*s)'} to GW; the flush applies
	 * {@code W += -(rate/batch) * (GW - batch*W)}, where batch is the summed
	 * sample weight.
	 */
	public class ON2Trainer implements Model.Trainer
	{
		double[][]		_W, _GW;
		double[]         _s, buf;
		VDouble			rate;
		Matrix				GW;
		int					_n;
		double			thresh, batch;
		MatrixTransposeTimesVector sW;

		public ON2Trainer()
		{
			_n=n;
			GW=new Matrix("GW",n,n);
			rate=new VDouble("rate",0.01);
			thresh=Shell.getDouble("anomaly",n*20);
			batch=0;

			_s=s.array();
			_GW=GW.getArray();
			_W=W.getArray();
			buf=new double[n]; // general purpose n-buffer
			sW = new MatrixTransposeTimesVector(buf,W,_s);
		}

		public String toString() { return "ON2Trainer:"+ICA.this; }
		public void dispose() { GW.dispose(); rate.dispose(); }
		public void reset() { GW.zero(); GW.changed(); batch=0; }
		public void oneshot() { accumulate(1); flush(); }
		public void accumulate() { accumulate(1); }

		/** Accumulates one sample with weight w, unless its source energy exceeds thresh. */
		public void accumulate(double w)
		{
			// HACK! samples above the "anomaly" energy threshold are ignored
			if (sourceModel.getEnergy()>thresh) return;
			accumulateGradient(w);
		}

		/**
		 * Adds w to the batch weight and {@code w * phi * buf'} to GW, where
		 * {@code buf = W'*s} and phi is the source model's gradient.
		 * Performs no threshold check.
		 */
		void accumulateGradient(double w)
		{
			batch+=w;

			sW.run(); // buf = W'*s
			double[] phi=sourceModel.getGradient();
			for (int i=0; i<_n; i++) {
				double[] row=_GW[i];
				double q=w*phi[i];
				for (int j=0; j<_n; j++) row[j] += q*buf[j];
			}
		}

		public void flush()
		{
			if (batch==0) return;
			double eta=-rate.value/batch;

			// W += eta*(GW - batch*W): the decay term is applied once per
			// batch here instead of once per sample in accumulate()
			for (int i=0; i<_n; i++) {
				double [] p = _W[i],	q = _GW[i];
				for (int j=0; j<_n; j++) p[j] += eta*(q[j]-batch*p[j]);
			}
			GW.changed();

			// reset for next batch
			W.changed();  GW.zero();
			batch=0;
		}
	}

	/**
	 * Variant of {@link ON2Trainer} that applies the decay term only to the
	 * columns of W whose source is "active" (|s_j| >= th) in each sample,
	 * instead of to the whole of W at flush time.
	 */
	public class ON2DecayWhenActive extends ON2Trainer {
		VDouble th=new VDouble("thresh",0.0);

		public String toString() { return "DecayWhenActiveTrainer:"+ICA.this; }

		/** Disposes the activity threshold as well as the inherited components. */
		public void dispose() { th.dispose(); super.dispose(); }

		public void accumulate(double w) {
			if (sourceModel.getEnergy()>thresh) return;
			accumulateGradient(w);

			// decay when active part: weighted by w like the gradient term,
			// since flush() normalises by the summed weight
			double activeThresh=th.value;
			for (int j=0; j<_n; j++)
				if (isActive(_s[j],activeThresh))
					for (int i=0; i<_n; i++)
						_GW[i][j] -= w*_W[i][j];
		}

		public void flush()
		{
			if (batch==0) return;
			double eta=-rate.value/batch;

			// now W += eta*GW (decay was already folded into GW)
			for (int i=0; i<_n; i++) {
				double [] p = _W[i],	q = _GW[i];
				for (int j=0; j<_n; j++) p[j] += eta*q[j];
			}
			GW.changed();
			W.changed();  GW.zero();
			batch=0;
		}
	}

	/** @return true if s >= t or s <= -t, i.e. |s| >= t */
	static boolean isActive(double s, double t) { return s>=t || s<=-t; }

	/** This trainer saves on an O(N^2) step during accumulation, at
		the expense of an O(N^3) flush. As long as the batch size
		is O(N), then it should be about the same overall. The advantage
		is the collected statistics are more transparent, and can be used
		to make scalar or diagonal updates more frequently.
		Each sample adds {@code w*(phi*s' - I)} to G; the flush computes
		{@code W += -(rate/batch) * G*W}. */
	public class ON3Trainer extends AnonymousTask implements Model.Trainer
	{
		Matrix			G;
		double[][]	_G, _W;
		double[]     _s, buf;
		VDouble		rate;
		int				_n;
		double		batch, thresh;

		public ON3Trainer()
		{
			_n=n;
			G=new Matrix("G",n,n);
			rate=new VDouble("rate",0.01);
			thresh=Shell.getDouble("anomaly",20*n);
			batch=0;

			_s=s.array();
			_G=G.getArray();
			_W=W.getArray();
			buf=new double[n]; // general purpose n-buffer
		}

		public String toString() { return "ON3Trainer:"+ICA.this; }

		/** this is so you can manipulate the matrix before flushing */
		public Matrix getGMatrix() { return G; }

		public void starting() { reset(); }
		public void run() { accumulate(); }

		public void dispose() { G.dispose(); rate.dispose();	super.dispose(); }
		public void oneshot() { accumulate(1); flush(); }
		public void reset() { G.zero(); batch=0; }
		public void accumulate() { accumulate(1); }

		/** Adds w*(phi*s' - I) to G, unless the source energy exceeds thresh. */
		public void accumulate(double w)
		{
			// HACK! samples above the "anomaly" energy threshold are ignored
			if (sourceModel.getEnergy()>thresh) return; 

			double q, p[];
			batch+=w;

			double[] phi=sourceModel.getGradient();
			for (int i=0; i<_n; i++) {
				p = _G[i]; q=w*phi[i];
				for (int j=0; j<_n; j++) p[j] += q*_s[j];
				p[i] -= w;
			}
		}

		public void flush()
		{
			if (batch==0) return;
			double eta=-rate.value/batch;

			G.changed();
			
			// G *= W, in place: row i of G*W depends only on row i of G,
			// so each row is computed into buf and then copied back
			for (int i=0; i<_n; i++) {
				for (int j=0; j<_n; j++) {
					double a=0;
					for (int k=0; k<_n; k++) a += _G[i][k]*_W[k][j];
					buf[j] = a;
				}
				Mathx.copy(buf,_G[i]);
			}

			// now W += eta*G
			for (int i=0; i<_n; i++) {
				double [] p = _W[i],	q = _G[i];
				for (int j=0; j<_n; j++) p[j] += eta*q[j];
			}

			reset(); // ready for next batch
		}
	}

	/**
	 * See Hyvarinen's paper that Nick gave me. Not finished: it accumulates
	 * dgamma(s) into f, but flush() does not use f yet and behaves exactly
	 * like {@link ON3Trainer#flush()}.
	 */
	public class NewtonTrainer extends ON3Trainer
	{
		Function		dgamma;
		VVector		f;			// scnd derivatives of log prior
		double[]     _f;

		public NewtonTrainer(Function dg)
		{
			dgamma=dg;
			f=new VVector("f",n);
			_f=f.array();
		}

		public void dispose() { f.dispose(); dgamma.dispose();	super.dispose(); }
		public void reset() { Mathx.zero(_f); super.reset(); }
		public void accumulate(double w)
		{
			// HACK! samples above the "anomaly" energy threshold are ignored
			if (sourceModel.getEnergy()>thresh) return;

			dgamma.apply(_s,buf);
			Mathx.add(_f,buf);
			super.accumulate(w);
		}

		public void flush()
		{
			if (batch==0) return;

			// first do some things to G, then do normal flush

			super.flush();
		}
	}
}