view src/samer/models/NoisyICA.java @ 8:5e3cbbf173aa tip

Reorganise some more
author samer
date Fri, 05 Apr 2019 22:41:58 +0100
parents bf79fb79ee13
children
line wrap: on
line source
/*
 *	Copyright (c) 2000, Samer Abdallah, King's College London.
 *	All rights reserved.
 *
 *	This software is provided AS IS and WITHOUT ANY WARRANTY;
 *	without even the implied warranty of MERCHANTABILITY or
 *	FITNESS FOR A PARTICULAR PURPOSE.
 */

package samer.models;
import  samer.core.*;
import  samer.core.types.*;
import  samer.tools.*;
import  samer.maths.*;
import  samer.maths.*;
import  samer.maths.opt.*;

public class NoisyICA extends NamedTask implements Model
{
	Vec				x;				// data (input)
	int				n, m;			// sizes (data, sources)
	Matrix		A;				// basis matrix
	Model		Ms, Me;	// source and noise models
	VVector	s, z, e;		// sources, reconstruction, error
	VDouble		E;				// energy

	Task			inference=new NullTask();


	// ----- variables used in computations -----------------

	private	double []	e_;
	private	double []	x_;
	private	double []	z_;
	private	double []	s_;
	private	VectorFunctionOfVector tA, tAt;


	public NoisyICA(Vec in,int outs) { this(new Node("noisyica"),in.size(),outs); setInput(in); }
	public NoisyICA(int ins,int outs) { this(new Node("noisyica"),ins,outs); }
	public NoisyICA(Node node, int inputs, int outputs)
	{
		super(node);
		Shell.push(node);

		n = inputs;
		m = outputs;

		s = new VVector("s",m);
		z = new VVector("z",n);
		e = new VVector("e",n);
		E = new VDouble("E");
		A = new Matrix("A",n,m);
		A.identity();

		e_ = e.array();
		z_ = z.array();
		s_ = s.array();
		tA = new MatrixTimesVector(A);
		tAt= new MatrixTransposeTimesVector(A);

		Shell.pop();
	}

	public Model getSourceModel() { return Ms; }
	public Model getNoiseModel() { return Me; }
	public void  setSourceModel(Model m) { Ms=m; Ms.setInput(s); }
	public void  setNoiseModel(Model m) { Me=m; Me.setInput(e); }
	public void  setInput(Vec in) { x=in; x_ = x.array(); }
	public Matrix basisMatrix() { return A; }
	public VVector output() { return s; }
	public VVector error() { return e; }
	public VVector reconstruction() { return z; }

	public int  getSize() { return n; }

	public void setInferenceTask(Task t) { inference=t; }

	public void infer() {
		try {	inference.run(); }
		catch (Exception ex) { 
			Shell.trace("error: "+ex);
			ex.printStackTrace();
			throw new Error("inference failed: "+ex); }
		tA.apply(s_,z_); Mathx.sub(e_,x_,z_);
		e.changed(); z.changed(); s.changed();
	}

	public void compute() {
		E.set(Me.getEnergy() + Ms.getEnergy());
		// what about dE/dx?
	}
	public double getEnergy() { return E.value; }
	public double [] getGradient() { return null; } // this is wrong

	/** get basis vector norms into given array */
	public void norms(double [] na)
	{
		double [][] M=A.getArray();

		Mathx.zero(na);
		for (int i=0; i<n; i++) {
			double [] Mi=M[i];
			for (int j=0; j<m; j++) {
				na[j] += Mi[j]*Mi[j];
			}
		}
	}

	public void run() { infer(); }
	public void dispose()
	{
		s.dispose();
		A.dispose();
		e.dispose();
		z.dispose();
		E.dispose();
		tA.dispose();
		tAt.dispose();

		super.dispose();
	}

	public Functionx functionx() { return null; }

	public Functionx posterior()	{
		// returns Functionx which evaluates E and dE/ds at current x
		return new Functionx() {
			Functionx fMs=Ms.functionx();
			Functionx fMe=Me.functionx();
			double [] e=new double[n];
			double [] ge=new double[n];
			double [] gs=new double[m];

			public void dispose() { fMs.dispose(); fMe.dispose(); }
			public void evaluate( Datum P) { P.f=evaluate(P.x,P.g); }
			public double evaluate( double [] s, double [] g)
			{
				tA.apply(s,z_); Mathx.sub(e,x_,z_);
				double Ee=fMe.evaluate(e,ge); tAt.apply(ge,gs); // gs=A'*gamma(e)
				double Es=fMs.evaluate(s,g);	Mathx.sub(g,gs); // g=gamma(s)-ge
				return Es+Ee;
			}
		};
	}

	public Trainer learnHebbian() 	{ return new MatrixTrainer(e_,A,s_); }

	public Trainer learnLewickiSejnowski()
	{
		final double [] h =new double[n];
		final double [] f =new double[m];
		return new MatrixTrainer(h,A,s_) {

			public void accumulate(double w) {
				// tAt.apply(Me.getGradient(),f); tA.apply(f,h);
				tA.apply(Ms.getGradient(),h);
				super.accumulate(w);
			}
			public void flush() { // flush with decay
				for (int j=0; j<this.m; j++)
					for (int i=0; i<this.n; i++)
						_T[i][j] -= count*_A[i][j];
				super.flush();
			}
		};
	}

	public Trainer learnDecayWhenActive()
	{
		final double [] h =new double[n];
		return new MatrixTrainer(h,A,s_) {
			VDouble				th=new VDouble("threshold",0.01);
			public void accumulate(double w) {
				// perhaps would like to get info out of optimiser here
				tA.apply(Ms.getGradient(),h);
				super.accumulate(w);

				// decay when active part
				double thresh=th.value;
				for (int j=0; j<this.m; j++)
					if (isActive(s_[j],thresh))
						for (int i=0; i<this.n; i++)
							_T[i][j] -= _A[i][j];
			}

			public VDouble getThreshold() { return th; }
		};
	}
	static boolean isActive(double s, double t) { return s>t || s<-t; }
}