view src/samer/models/notyet/PCA.java @ 8:5e3cbbf173aa tip

Reorganise some more
author samer
date Fri, 05 Apr 2019 22:41:58 +0100
parents bf79fb79ee13
children
line wrap: on
line source
/*
 *	Copyright (c) 2002, Samer Abdallah, King's College London.
 *	All rights reserved.
 *
 *	This software is provided AS IS and WITHOUT ANY WARRANTY;
 *	without even the implied warranty of MERCHANTABILITY or
 *	FITNESS FOR A PARTICULAR PURPOSE.
 */

package samer.models;
import  samer.core.*;
import  samer.core.Agent.*;
import  samer.core.types.*;
import  samer.tools.*;
import  samer.maths.*;
import  samer.maths.random.*;



/**
 * Linear model with a square n-by-n weight matrix W that maps an input
 * vector x to an output vector s (s = W x, see {@link #infer()}).
 *
 * <p>Typical use: construct, attach an input with {@link #setInput(Vec)},
 * then call {@link #run()} / {@link #infer()} to refresh the output.
 * {@link #getTrainingTask()} returns a {@link Trainer} that adapts W in
 * batches.
 *
 * <p>NOTE(review): <code>sourceModel</code> is never assigned inside this
 * class; it must be set by other code in the package before
 * {@link #getEnergy()} or {@link Trainer#run()} is called, otherwise those
 * calls fail with a NullPointerException.
 */
public abstract class PCA extends NamedTask implements Model
{
	int				n;				// dimension of input and output
	Vec			x;				// input vector; null until setInput() is called
	VVector		s;				// output vector
	Matrix			W;				// n-by-n weight matrix
	VDouble		logA;			// "log |A|" term of the energy; never updated in this class
	double[]		_g,_s; // cached dE/dx
	Model			sourceModel;	// model of the outputs s; see class note

	MatrixTimesVector					infer;	// null until setInput() is called
	MatrixTransposeTimesVector	grad;

	/**
	 * Creates a model sized to <code>input</code> and attaches it as the input.
	 *
	 * @param input the input vector; its size determines n
	 */
	public PCA(Vec input) { this(input.size()); setInput(input); }

	/**
	 * Creates a model of dimension N with no input attached.
	 * {@link #setInput(Vec)} must be called before {@link #infer()}.
	 *
	 * @param N dimension of the input and output vectors
	 */
	public PCA(int N)
	{
		super("pca");
		Shell.push(node);

		n = N;

		x = null;
		W = new Matrix("W",n,n);
		s = new VVector("s",n);
		logA = new VDouble( "log |A|",0.0,VDouble.SIGNAL);
		infer = null;
		Shell.pop();

		grad = new MatrixTransposeTimesVector(W);
		_g = new double[n];
		_s = s.array();
	}

	/** @return the dimension n of the model */
	public int getSize() { return n; }

	/** @return the output vector s */
	public VVector output() { return s; }

	/**
	 * Attaches the input vector and builds the inference operation for it.
	 *
	 * <p>NOTE(review): a previously created <code>infer</code> object is
	 * replaced without being disposed; confirm whether that matters.
	 *
	 * @param input the new input vector; must have size n
	 * @throws Error if <code>input.size()</code> differs from n
	 */
	public void setInput(Vec input) {
		if (input.size()!=n) throw new Error("Input vector is the wrong size");
		x = input;
		infer = new MatrixTimesVector(s,W,x);
	}

	/**
	 * Disposes of the objects owned by this model and then of the task itself.
	 * Safe to call even if no input was ever attached.
	 */
	public void dispose()
	{
		W.dispose();
		logA.dispose();
		s.dispose();
		// infer only exists once setInput() has been called; without this
		// guard a model built with PCA(int) would throw here and skip the
		// remaining clean-up.
		if (infer != null) {
			infer.dispose();
		}
		grad.dispose();

		super.dispose();
	}

	/**
	 * Recomputes the output and signals that s has changed.
	 * Requires a prior call to {@link #setInput(Vec)}.
	 */
	public void infer() { infer.run(); s.changed(); } // compute s=Wx

	/**
	 * Refreshes the cached gradient array returned by {@link #getGradient()}
	 * by applying <code>grad</code> to the output array.
	 * Presumably this stores W-transpose times s in _g; confirm against
	 * MatrixTransposeTimesVector.apply.
	 */
	public void compute() {
		grad.apply(_g,_s);
	}

	/** @return <code>logA.value</code> plus the energy reported by the source model */
	public double getEnergy() { return logA.value + sourceModel.getEnergy(); }

	/** @return the cached gradient array (shared, not copied) */
	public double[] getGradient() { return _g; }

	public void starting() {}
	public void stopping() {}

	/** Runs one inference step; equivalent to {@link #infer()}. */
	public void run() { infer(); }

	/** @return a new {@link Trainer} bound to this model */
	public Task getTrainingTask() { return new Trainer(); }

	/**
	 * Batch trainer for W.
	 *
	 * <p>Each {@link #run()} adds <code>phi * s^T - I</code> to the
	 * accumulator G, where phi is the source model's gradient array.
	 * {@link #flush()} then applies <code>W += eta * (G W)</code> with
	 * <code>eta = -rate / batch</code> and resets the accumulator.
	 */
	public class Trainer extends NamedTask
	{
		Matrix			G;			// accumulator for the batch
		double[][]		_G, _W;		// backing arrays of G and W
		// NOTE: _g is never assigned or read here and shadows PCA._g;
		// it is kept only so the declared fields are unchanged.
		double[]         _g,_s, buf;
		VDouble		rate;		// learning rate
		int				batch, _n;	// samples accumulated since last flush; cached n

		public Trainer()
		{
			super("learn",PCA.this.getNode());
			Shell.push(Trainer.this.node);
			_n=n;
			G=new Matrix("G",n,n);
			rate=new VDouble("rate",0.01);
			Shell.pop();
			batch=0;

			_s=s.array();
			_G=G.getArray();
			_W=W.getArray();
			buf=new double[n]; // general purpose n-buffer
		}

		/** Disposes of the accumulator and the rate parameter, then of the task. */
		public void dispose() { G.dispose(); rate.dispose();	super.dispose(); }

		/** Clears the accumulator and the batch counter before training starts. */
		public void starting() { G.zero(); G.changed(); batch=0; }

		/**
		 * Accumulates one sample: <code>G[i][j] += phi[i]*s[j]</code> for all
		 * i, j, then subtracts 1 from each diagonal element G[i][i].
		 */
		public void run()
		{
			double q, p[];
			batch++;

			double[] phi=sourceModel.getGradient();
			for (int i=0; i<_n; i++) {
				p = _G[i]; q=phi[i];
				for (int j=0; j<_n; j++) p[j] += q*_s[j];
				p[i] -= 1;
			}
		}

		/** Signals that the accumulator G has changed. */
		public final void updateG() { G.changed(); }

		/**
		 * Applies the accumulated batch to W and resets the accumulator.
		 * Does nothing if no samples have been accumulated since the last
		 * reset.
		 */
		public void flush()
		{
			// With an empty batch, -rate/batch is -Infinity (or NaN), and
			// adding eta*G with G == 0 would overwrite all of W with NaN.
			if (batch == 0) {
				return;
			}

			double eta=-rate.value/batch;

			// this is going to do a matrix G *= W, in place.
			// Row i of the product depends only on row i of G, so each row
			// can be built in buf and written back before moving on.
			for (int i=0; i<_n; i++) {
				for (int j=0; j<_n; j++) {
					double a=0;
					for (int k=0; k<_n; k++) a += _G[i][k]*_W[k][j];
					buf[j] = a;
				}
				Mathx.copy(buf,_G[i]);
			}

			// now W += eta*G
			for (int i=0; i<_n; i++) {
				double [] p = _W[i],	q = _G[i];
				for (int j=0; j<_n; j++) p[j] += eta*q[j];
			}

			// reset for next batch
			W.changed();
			G.zero();
			batch=0;
		}
	}
}