view src/samer/models/ICAWithScaler.java @ 5:b67a33c44de7

description: Remove some crap, etc
author:      samer
date:        Fri, 05 Apr 2019 21:34:25 +0100
parents:     bf79fb79ee13
children:
/*
 *	Copyright (c) 2002, Samer Abdallah, King's College London.
 *	All rights reserved.
 *
 *	This software is provided AS IS and WITHOUT ANY WARRANTY;
 *	without even the implied warranty of MERCHANTABILITY or
 *	FITNESS FOR A PARTICULAR PURPOSE.
 */

package samer.models;
import  samer.core.*;
import  samer.core.Agent.*;
import  samer.core.types.*;
import  samer.tools.*;
import  samer.maths.*;

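/**
 *	ICA model augmented with a per-source scaling vector k: after the parent
 *	ICA computes s = W x, infer() multiplies each source by k[i], so the
 *	effective unmixing matrix is diag(k)*W.  fold() pushes the accumulated
 *	factors into the rows of W and resets k to ones, leaving the overall
 *	transform unchanged.  Two trainers are provided: a DifferentialTrainer
 *	that extends the usual ICA trainer, and a cheaper ScalerTrainer that
 *	adapts only k.
 */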
public class ICAWithScaler extends ICA
{
	VVector  		k;			// state, scaling vector
	double[]		__k, tmp;

	public ICAWithScaler(Vec input) { this(input.size()); setInput(input); }
	public ICAWithScaler(int N)
	{
		super(N);
		Shell.push(node);
		k = new VVector("k",n);
		Shell.pop();
		__k=k.array();
		tmp=new double[n];

		for (int i=0; i<n; i++) __k[i]=1;
		k.changed();
	}

	public void dispose() { k.dispose(); super.dispose(); }
	public void infer() {	// this overrides ICA.infer
		infer.run(); // compute s=Wx
		Mathx.mul(s.array(),__k);	// then scale elementwise: s[i] *= k[i]
		s.changed();
	}

	public void compute() {
		// chain rule through s[i] = k[i]*(W x)[i]: scale the source-model
		// gradient by k before applying the parent's gradient operator
		Mathx.mul(tmp,sourceModel.getGradient(),__k);
		grad.apply(tmp,_g);
	}

	public samer.maths.opt.Functionx functionx() { return null; }

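	/** Fold the accumulated scale factors into W: each row W[i] is multiplied
		by k[i] and k[i] is reset to 1, so diag(k)*W is unchanged. */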
	public void fold() {
		for (int i=0; i<n; i++) {
			Mathx.mul(W.getArray()[i],__k[i]);
			__k[i]=1;
		}
		k.changed();
		W.changed();
	}

	public Trainer getDiffTrainer() { return new DifferentialTrainer(); }
	public Trainer getScaleTrainer() { return new ScalerTrainer(); }


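	/** A variant of ON3Trainer whose flush() first diverts the trace and
		diagonal of the accumulated statistics G into the scaling vector k
		(diffFlush), folds k back into W, and then runs the ordinary ICA
		update.  alpha sets the overall scale rate, beta the per-component
		stretch rate. */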
	public class DifferentialTrainer extends ON3Trainer {
		VDouble		scaleRate,stretchRate;
		double		lastflush;

		public DifferentialTrainer() {
			scaleRate=new VDouble("alpha",0.1);
			stretchRate=new VDouble("beta",0.05);
		}

		public void reset() { super.reset(); lastflush=0; }
		public void flush() { diffFlush(); fold(); super.flush(); lastflush=0; }
		public void diffFlush() // flush multipliers to k instead of W
		{
			double batchlet=batch-lastflush;
			if (batchlet==0) return;

			// do differential learning on trace & diagonal of G
			double alpha=scaleRate.value/batchlet;
			double beta=stretchRate.value/batchlet;

			// compute per-component factors exp(-dl) and apply them to the
			// scaling vector k rather than to the rows of W; the diagonal of G
			// is then zeroed so super.flush() does not apply it a second time
			double mu=G.trace()/n,dl;
			for (int i=0; i<_n; i++) {
				dl=alpha*mu+beta*(_G[i][i]-mu);
				double factor=Math.exp(-dl);
				// if (Double.isNaN(factor)) throw new Error("alt: NaN"+i);
				__k[i]*=factor; // instead of Mathx.mul(_W[i],factor);
				_G[i][i]=0;
			}
			k.changed();
			lastflush=batch;
		}
	}

	/** Trains ONLY the scaling vector k, not the ICA matrix W, and so is much
		faster than running the differential trainer with a zero learning rate.
		(See the usage sketch at the end of this file.) */
	
	public class ScalerTrainer extends AnonymousTask implements Model.Trainer
	{
		VVector		G;
		double[]		_G;
		double[]     _g,_s;
		VDouble		scaleRate,stretchRate;
		int				_n;
		double		batch,thresh;
		
		public ScalerTrainer()
		{
			_n=n;
			G=new VVector("G",n);
			thresh=Shell.getDouble("anomaly",20*n);
			scaleRate=new VDouble("alpha",0.02);
			stretchRate=new VDouble("beta",0.002);
			batch=0;

			_s=s.array();
			_G=G.array();
		}

		public void starting() { reset(); }
		public void run() { accumulate(); }

		public void dispose() { G.dispose(); scaleRate.dispose(); stretchRate.dispose(); super.dispose(); }
		public void oneshot() { accumulate(); flush(); }
		public void reset() { Mathx.zero(_G); batch=0; }
		public void accumulate() { accumulate(1); }
		public void accumulate(double w) {
			// HACK: skip anomalous frames whose source-model energy exceeds the threshold
			if (sourceModel.getEnergy()>thresh) return;
			batch+=w;

			double[] phi=sourceModel.getGradient();
			for (int i=0; i<_n; i++) _G[i] += w*(phi[i]*_s[i] - 1);
		}

		public void flush()
		{
			if (batch==0) return;

			G.changed();
			
			// do differential learning on trace & diagonal of G
			double alpha=scaleRate.value/batch;
			double beta=stretchRate.value/batch;
			
			// compute per-component factors exp(-dl) and apply them to the
			// scaling vector k rather than to the rows of W
			double mu=Mathx.sum(_G)/n, dl;
			for (int i=0; i<_n; i++) {
				dl=alpha*mu+beta*(_G[i]-mu);
				double factor=Math.exp(-dl);
				if (Double.isNaN(factor)) throw new Error("alt: NaN"+i);
				__k[i]*=factor; // instead of Mathx.mul(_W[i],factor);
			}
			k.changed();
		
			reset(); // ready for next batch
		}
	}
}
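
/*
 *	Minimal usage sketch (an illustration only: the Shell/task wiring and the
 *	source-model setup required by the parent ICA class are omitted; the names
 *	x, N and frames are hypothetical, and Model.Trainer is assumed to expose
 *	accumulate() and flush() as ScalerTrainer does):
 *
 *		VVector x = new VVector("x", N);            // observed vector
 *		ICAWithScaler ica = new ICAWithScaler(x);   // unmixing model with scaler k
 *		Model.Trainer scaler = ica.getScaleTrainer();
 *
 *		for (int t=0; t<frames; t++) {
 *			// ... write the next observation into x ...
 *			ica.infer();          // s = diag(k) * W * x
 *			scaler.accumulate();  // gather per-component statistics
 *		}
 *		scaler.flush();           // update k from the accumulated statistics
 *		ica.fold();               // fold k back into the rows of W
 */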