view src/samer/models/SparseICA.java @ 5:b67a33c44de7

Remove some crap, etc
author samer
date Fri, 05 Apr 2019 21:34:25 +0100
parents bf79fb79ee13
children
line wrap: on
line source
/*
 *	Copyright (c) 2002, Samer Abdallah, King's College London.
 *	All rights reserved.
 *
 *	This software is provided AS IS and WITHOUT ANY WARRANTY;
 *	without even the implied warranty of MERCHANTABILITY or
 *	FITNESS FOR A PARTICULAR PURPOSE.
 */

package samer.models;
import  samer.core.*;
import  samer.core.Agent.*;
import  samer.core.types.*;
import  samer.tools.*;
import  samer.maths.*;
import  java.util.*;

/**
 * ICA model with a sparse unmixing matrix W. Inference computes the source
 * vector s from the input x via W.times (s = W x). Energy and gradient with
 * respect to x are derived from a separate source model (see
 * setOutputModel). The sparsity pattern of W can be built from a similarity
 * matrix with init1 or init2.
 */
public class SparseICA extends NamedTask implements Model, Agent
{
	Model			sourceModel;	// model of the source vector s
	int					n;			// dimensionality of x and s
	Vec				x;			// input vector; null until setInput is called
	SparseMatrix	W;			// sparse unmixing matrix (basis matrix A is not maintained)
	VVector  		s;				// state
	VDouble			logA;		// log |A|, added to the source model's energy
	double[]			_g; // cached dE/dx

	/** Creates a model the same size as input and binds input as the model's input vector. */
	public SparseICA(Vec input) { this(input.size()); setInput(input); }

	/**
	 * Creates an N-dimensional model. W is reset with W.identity(); the input
	 * must be bound later with setInput before infer() is called.
	 */
	public SparseICA(int N)
	{
		super("ica");
		Shell.push(node);

		n = N;

		x = null;
		W = new SparseMatrix("W");
		// A = new SparseMatrix("A",n,n);
		s = new VVector("s",n);
		logA = new VDouble( "log |A|");
		Shell.pop();

		W.identity(); W.changed();
		_g = new double[n];
	}

	public int getSize() { return n; }
	public VVector output() { return s; }
	public Model getOutputModel() { return sourceModel; }
	public void setOutputModel(Model m) {	sourceModel=m; }

	/**
	 * Binds the input vector.
	 * @throws Error if input does not have exactly n elements
	 */
	public void setInput(Vec input) {
		if (input.size()!=n) throw new Error("Input vector is the wrong size");
		x = input;
	}

	public SparseMatrix getWeightMatrix() { return W; }
	// public Matrix getBasisMatrix() { return A; }

	/** Releases W, logA and s, then the task itself. */
	public void dispose()
	{
		W.dispose(); // A.dispose();
		logA.dispose();
		s.dispose();
		// infer.dispose();
		// grad.dispose();

		super.dispose();
	}

	/** Computes s = W x and notifies observers of s. */
	public void infer() { W.times(s.array(),x.array()); s.changed(); }

	/** Sets _g = W' g, where g is the source model's current gradient dE/ds. */
	public void compute() { W.transposeTimes(_g,sourceModel.getGradient()); }

	/**
	 * Returns the source model's energy plus log |A|.
	 * NOTE(review): nothing in this class updates logA (the "logdet" command
	 * is disabled), so this term keeps whatever value logA was created with.
	 */
	public double getEnergy() { return sourceModel.getEnergy()+logA.value; }

	/** Returns the gradient buffer filled by compute(). */
	public double[] getGradient() { return _g; }

	/** Not supported by this model: always returns null. */
	public samer.maths.opt.Functionx functionx() { return null; }

	public void starting() { }
	public void stopping() {}
	public void run() { infer(); }

	public void getCommands(Registry r) { r.add("basis").add("logdet"); }

	/**
	 * NOTE: both registered commands ("basis" and "logdet") are currently
	 * disabled. The dense-matrix implementation below is kept for reference,
	 * but it needs a sparse inverse/SVD, so this method does nothing.
	 */
	public void execute(String cmd, Environment env)
	{
	/*
		if (cmd.equals("basis")) {
			Shell.print("computing ICA basis...");
			// sparse matrix inverse
			// much easier if matrix is block decomposable!
			A.assign(W.inverse());
			A.changed();
			Shell.print("...done.");
		} else if (cmd.equals("logdet")) {
			Shell.print("computing SVD...");
			// double [] s=W.svd().getSingularValues();
			// Shell.print("...done.");
			// Mathx.log(s);
			// logA.set(-Mathx.sum(s));
		}
	 */
	}

	public Trainer getTrainer() { return new ON2Trainer(); }
	public Trainer getAltTrainer() { return new ON3Trainer(); }

	/** This trainer uses an O(N^2) run step and an O(N^2) flush.
		NOTE(review): GW copies W's sparsity pattern when the trainer is
		constructed; if init1/init2 rebuild W afterwards the patterns may
		no longer match - confirm the intended call order. */

	public class ON2Trainer implements Model.Trainer
	{
		double[]         buf;
		VDouble		rate;
		SparseMatrix	GW;
		int					_n;
		double			thresh;
		double			batch;

		public ON2Trainer()
		{
			_n=n;
			Shell.push(getNode());
			GW=new SparseMatrix("GW",W); // copy pattern from W
			rate=new VDouble("rate",0.01);
			thresh=Shell.getDouble("anomaly",n*20);
			Shell.pop();
			batch=0;

			buf=new double[n]; // general purpose n-buffer
		}

		public void dispose() { GW.dispose(); rate.dispose(); }
		public void reset() { GW.zero(); GW.changed(); batch=0; }
		public void oneshot() { accumulate(1); flush(); }
		public void accumulate() { accumulate(1); }

		/** Adds w * grad * (W' s)' to GW and w to the batch weight. */
		public void accumulate(double w) {
			batch+=w;

			W.transposeTimes(buf,s.array()); // buf = W'*s
			Mathx.mul(buf,w); 						// buf = w*W'*s
			GW.addOuterProduct(sourceModel.getGradient(),buf); // GW+=grad*buf'
			// could subtract w*W here
		}

		/** Applies the accumulated update to W and clears the statistics; no-op for an empty batch. */
		public void flush() 	{
			if (batch==0) return;
			W.icaUpdate(GW,-rate.value,batch); // W += eta(GW/batch - W)
			W.changed();
			GW.zero();
			batch=0;
		}
	}

	/** This trainer saves on an O(N^2) step during accumulation, at
		the expense of an O(N^3) flush. As long as the batch size
		is O(N), then it should be about the same overall. The advantage
		is the collected statistics are more transparent, and can be used
		to make scalar or diagonal updates more frequently.
		NOTE: flush() is not implemented yet - it discards the accumulated
		statistics without updating W. */


	public class ON3Trainer extends AnonymousTask implements Model.Trainer
	{
		SparseMatrix	G;
		double[]     	buf;
		VDouble		rate;
		int					_n;
		double			batch, thresh;

		public ON3Trainer()
		{
			_n=n;
			G=new SparseMatrix("G");
			// set up G so that it has the right links
			rate=new VDouble("rate",0.01);
			thresh=Shell.getDouble("anomaly",20*n);
			batch=0;

			buf=new double[n]; // general purpose n-buffer
		}

		/** this is so you can manipulate the matrix before flushing */
		public SparseMatrix getGMatrix() { return G; }

		public void starting() { reset(); }
		public void run() { accumulate(); }

		public void dispose() { G.dispose(); rate.dispose();	super.dispose(); }
		public void oneshot() { accumulate(1); flush(); }
		public void reset() { G.zero(); batch=0; }
		public void accumulate() { accumulate(1); }

		/** Adds w * grad * s' to G and w to the batch weight. */
		public void accumulate(double w) {
			double _s[]=s.array();
			batch+=w;

			Mathx.copy(sourceModel.getGradient(),buf);
			Mathx.mul(buf,w);
			G.addOuterProduct(buf,_s);
			// subtract identity to prevent loss of precision?
		}

		/** Stub: the W update is not implemented; only clears the statistics. */
		public void flush()
		{
			if (batch==0) return;

			// compute deltas: dW=W - G*W/batch;
			// now W += eta*dW
			// double eta=rate.value;
			// for (each link) W[i][j] += eta*dW[i][j];
			reset(); // ready for next batch
		}
	}

	// Initialisation methods

	/**
	 * Returns every unit pair (i,j) with i&lt;j as an Edge weighted by R[i][j],
	 * held in a TreeSet so iteration visits the most similar pairs first
	 * (Edge orders by decreasing weight).
	 */
	private TreeSet sortedEdges(Matrix R) {
		double [][]	_R=R.getArray();
		TreeSet		edges=new TreeSet();

		Shell.status("Building sorted edge list...");
		for (int i=0; i<n; i++) {
			for (int j=i+1; j<n; j++) {
				edges.add(new Edge(i,j,_R[i][j]));
			}
		}
		Shell.status("Edge list complete.");
		return edges;
	}

	/**
	 * Builds the sparsity pattern of W from a similarity matrix: every
	 * diagonal element is set to 1, and the M most similar unit pairs (i,j)
	 * get links W[i][j] and W[j][i], both initialised to 0.
	 *
	 * @param R N by N symmetric similarity matrix (only the upper triangle is read)
	 * @param M number of unit pairs (butterflies) to build into the ICA model;
	 *          must lie in [0, N(N-1)/2]
	 * @throws IllegalArgumentException if M is out of range; checked before W is modified
	 */
	public void init1(Matrix R, int M) {
		long maxPairs=(long)n*(n-1)/2; // long: n*(n-1) overflows int for large n
		if (M<0 || M>maxPairs)
			throw new IllegalArgumentException("init1: M="+M+" must be in [0,"+maxPairs+"]");

		TreeSet edges=sortedEdges(R);
		Shell.status("Adding edges to sparse matrix...");

		W.allocate(n+2*M);
		for (int i=0; i<n; i++) W.addElement(i,i,1);

		Iterator it=edges.iterator();
		for (int k=0; k<M; k++) {
			Edge e=(Edge)it.next();
			W.addElement(e.i,e.j,0);
			W.addElement(e.j,e.i,0);
		}
		Shell.status("Sparse matrix complete.");
	}

	/**
	 * Builds the sparsity pattern of W as disjoint butterfly pairs. Pairs are
	 * chosen greedily, most similar first, among units not yet paired. Each
	 * pair gets a 2x2 block with diagonal 1 and off-diagonal 0. If n is odd,
	 * the one unit left unpaired gets only its diagonal element (1).
	 * This version only builds disjoint pairs, ie at most 2N elements per layer.
	 * Need a version that builds bigger local modules.
	 *
	 * @param R N by N symmetric similarity matrix (only the upper triangle is read)
	 */
	public void init2(Matrix R) {
		boolean []	flags=new boolean[n]; // flags[i]: unit i already paired
		TreeSet		edges=sortedEdges(R);
		Shell.status("Adding edges to sparse matrix...");

		W.allocate(2*n);
		Iterator it=edges.iterator();
		// each pass pairs two unpaired units; stop once fewer than two remain
		// (the old k<n bound ran off the end of the edge list for odd n)
		for (int k=0; k+1<n; k+=2)  {
			Edge e;
			// skip edges touching a paired unit; since every pair i<j is in
			// the list, an edge between two unpaired units always exists
			do { e=(Edge)it.next(); } while (flags[e.i] || flags[e.j]);
			W.addElement(e.i,e.i,1);
			W.addElement(e.i,e.j,0);
			W.addElement(e.j,e.i,0);
			W.addElement(e.j,e.j,1);
			flags[e.i]=flags[e.j]=true;
		}
		// odd n: give the leftover unit its diagonal so every unit has one
		for (int i=0; i<n; i++) if (!flags[i]) W.addElement(i,i,1);
		Shell.status("Sparse matrix complete.");
	}
}

/**
 * A weighted pair of units (i,j), used when building the sparsity pattern
 * of a SparseICA model. Edges are ordered by DECREASING weight x, with ties
 * broken by increasing i and then increasing j. A sorted collection of Edges
 * therefore yields the most similar pairs first.
 */
class Edge implements Comparable {
	int	i, j;
	double x;

	public Edge(int i, int j, double x) { this.i=i; this.j=j; this.x=x; }

	/**
	 * Compares by decreasing x, then increasing i, then increasing j.
	 * NaN weights sort after all other weights (treated as least similar).
	 * This keeps the ordering total and transitive, which TreeSet requires,
	 * even if the similarity matrix contains NaN.
	 */
	public int compareTo(Object o) {
		Edge e=(Edge)o;
		boolean nan=Double.isNaN(x), otherNaN=Double.isNaN(e.x);
		if (nan!=otherNaN) return nan ? 1 : -1;
		// NB: REVERSE ordering on x (two NaNs fall through to the index tie-break)
		if (x>e.x) return -1;
		else if (x<e.x) return 1;
		else if (i<e.i) return -1;
		else if (i>e.i) return 1;
		else if (j<e.j) return -1;
		else if (j>e.j) return 1;
		else return 0;
	}

	/** Formats the edge as "(i,j:x)". */
	public String toString() {
		return "(" + i + "," + j + ":" + x + ")";
	}
}