// src/samer/models/GeneralisedExponential.java  (changeset 5:b67a33c44de7)
// author: samer -- Fri, 05 Apr 2019 21:34:25 +0100
package samer.models;

import samer.maths.*;
import samer.maths.opt.*;
import samer.tools.*;
import samer.core.*;
import samer.core.types.*;


public class GeneralisedExponential extends NullTask implements Model {
	Vec		input;		// current input vector x
	VVector	alpha, e, grad;	// exponents alpha_i, per-component energies e_i, gradient dE/dx
	int			N;		// dimensionality
	VDouble	E;			// total energy E = -log p(x)
	double[]	x, g, e0, _e, a;	// raw arrays behind input, grad, the e0 terms, e and alpha
	double	L0;			// x-independent part of E (normalisation, a function of alpha)

	public GeneralisedExponential(Vec x) { this(x.size()); setInput(x); }
	public GeneralisedExponential(int n)  {
		N=n;
		E=new VDouble("E");
		e=new VVector("e",N);
		grad=new VVector("phi",N);
		alpha=new VVector("alpha",N);
		alpha.addSaver();

		g=grad.array(); // new double[N];
		e0=new double[N];
		a=alpha.array();
		_e=e.array();

		Mathx.setAll(a,1.0);
		precompute();	// initialise e0 and L0 to match alpha=1
	}

	public String toString() { return "GeneralisedExponential("+input+")"; }
	public void setInput(Vec in) { input=in; x=input.array(); }
	public int getSize() { return N; }
	public void dispose() {
		alpha.dispose();
		grad.dispose();
		E.dispose();
		e.dispose();
	}

	public VVector	getEnergyVector() { return e; }
	public VDouble	getEnergySignal() { return E; }
	public double	getEnergy() { return E.value; }
	public double [] getGradient() { return g; }
	public VVector  getAlphas() { return alpha; }

	public void run() { compute(); }
	public void infer() {}	// nothing to infer: the input is observed directly
	public void compute() {

		// per-component energies e_i = |x_i|^alpha_i

		for (int i=0; i<N; i++) _e[i] = Math.pow(Math.abs(x[i]),a[i]);

		// gradient of the energy: g_i = dE/dx_i = alpha_i*e_i/x_i
		// (taken as zero at x_i = 0, where |x|^alpha is not differentiable)
		for (int i=0; i<N; i++) {
			if (x[i]==0) g[i]=0;
			else g[i] = a[i]*(_e[i]/x[i]);
		}

		e.changed();
		grad.changed();
		E.set(Mathx.sum(_e)+L0);
	}
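
	// The model: each component is generalised-exponential distributed,
	//    p(x_i) = (alpha_i/(2*Gamma(1/alpha_i))) * exp(-|x_i|^alpha_i),
	// so the energy is
	//    E = -log p(x) = sum_i |x_i|^alpha_i + sum_i [log(2/alpha_i) + logGamma(1/alpha_i)]
	//      = sum(e) + L0,
	// with gradient dE/dx_i = alpha_i*sign(x_i)*|x_i|^(alpha_i-1) = alpha_i*e_i/x_i,
	// which is what compute() evaluates. alpha=1 gives a Laplacian density,
	// alpha=2 a Gaussian.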

	private void precompute() {
		// x-independent part of the energy -log p(x), a function of alpha only:
		// e0_i = log(2/alpha_i) + logGamma(1/alpha_i)
		for (int i=0; i<N; i++) e0[i]= Math.log(2/a[i])+logGamma(1/a[i]);
		L0=Mathx.sum(e0);
	}

	public Functionx functionx() {
		return new Functionx() {
			double [] __e=new double[N];
			public void dispose() {}
			public void evaluate(Datum P) { P.f=evaluate(P.x,P.g); }
			public double evaluate(double [] x, double [] g) {
				for (int i=0; i<N; i++) {
					if (x[i]==0) { g[i]=0; __e[i]=0; }
					else {
						__e[i] = Math.pow(Math.abs(x[i]),a[i]);
						g[i] = a[i]*(__e[i]/x[i]);
					}
				}
				return Mathx.sum(__e)+L0;
			}
		};
	}
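
	// functionx() packages the same energy and gradient as compute(), writing
	// into caller-supplied arrays and a private scratch buffer __e; presumably
	// this is the hook for the optimisers in samer.maths.opt, usable without
	// disturbing the model's published vectors.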

	public Trainer getTrainer() { return new Trainer(); }

	public class Trainer extends AnonymousTask implements Model.Trainer {
		VDouble	rate;		// learning rate
		double[]	A;		// statistics
		double	count;

		// maximum-likelihood estimation of alpha by gradient ascent:
		// A[i] accumulates the data-dependent part of dL/dalpha_i,
		// namely -sum over samples of w*|x_i|^alpha_i*log|x_i| (see flush)

		public Trainer() {
			rate=new VDouble("rate",0.001);
			A=new double[N];
		}

		public String toString() { return "Trainer:"+GeneralisedExponential.this; }
		public VDouble getRate() { return rate; }
		
		public void reset() { Mathx.zero(A);	count=0; }
		public void accumulate() { accumulate(1); }
		public void accumulate(double w) {
			for (int i=0; i<N; i++) {
				if (x[i]!=0) A[i] -= w*_e[i]*Math.log(Math.abs(x[i]));
			}
			count+=w;
		}

		public void flush() {
			if (count==0) return;
			double eta=rate.value;

			for (int i=0; i<N; i++) {
				double ra=1/a[i];
				a[i] += eta*(A[i]/count + ra*(1+ra*digamma(ra)));
				// make sure a[i] not negative:
				if (a[i]<=0) a[i]=0.01; // a small value

				// precompute for next round
				e0[i]= Math.log(2*ra)+logGamma(ra);
			}
			L0=Mathx.sum(e0);
			alpha.changed();
			reset();
		}
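
		// Derivation of the step above: per component, writing ra = 1/alpha,
		//    dL/dalpha = -|x|^alpha*log|x| + 1/alpha + (1/alpha^2)*psi(1/alpha),
		// whose sample average is A[i]/count + ra*(1 + ra*digamma(ra)),
		// so flush() is plain gradient ascent on the mean log-likelihood at
		// rate eta, with alpha clamped away from zero afterwards.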

		public void oneshot() { reset(); accumulate(1); flush(); }
		public void dispose() { rate.dispose(); }
		public void starting() { reset(); }
		public void run() { accumulate(1); flush(); }

	}

	// digamma constants: small-argument cutoff S, recurrence threshold C,
	// asymptotic-series coefficients, and -gamma (Euler-Mascheroni constant)
	private static double S=1e-5, C=8.5,
		S3=8.33333333333333333333E-2,	// 1/12
		S4=8.33333333333333333333E-3,	// 1/120
		S5=3.96825396825396825397E-3,	// 1/252
		D1=-0.5772156649;				// -gamma

	public static double digamma(double t) {
		double z=0;
		if (t<S) return D1-1/t;		// psi(t) ~ -gamma - 1/t as t -> 0
		for (z=0; t<C; t++) z-=1/t;	// recurrence psi(t) = psi(t+1) - 1/t
		double r=1/t;
		z+=Math.log(t) - 0.5*r;		// asymptotic expansion for large t:
		r*=r;
		z-=r*(S3 - r*(S4 - r*S5));	// -1/(12t^2) + 1/(120t^4) - 1/(252t^6)
		return z;
	}
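
	// Quick sanity checks (standard values, not computed from this codebase):
	// digamma(1) = -gamma = -0.57722..., digamma(0.5) = -gamma - 2*log(2) = -1.96351...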

	private static double gamma(double t) {
		return samer.functions.Gamma.gamma(t);
		// return edu.uah.math.distributions.Functions.gamma(t);
	}
	private static double logGamma(double t) {
		return samer.functions.Gamma.logOfGamma(t);
		// return edu.uah.math.distributions.Functions.logGamma(t);
	}
}
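
// A usage sketch, assuming the obvious samer.core wiring (the exact VVector
// constructor and task scheduling may differ):
//
//   VVector x = new VVector("x", 16);               // input vector of size 16
//   GeneralisedExponential model = new GeneralisedExponential(x);
//   Model.Trainer t = model.getTrainer();
//   // per sample: fill x.array() with data, then
//   model.compute();      // update energy E and gradient phi for the current x
//   t.accumulate(1);      // gather statistics for the alpha update
//   t.flush();            // one gradient step on alpha, then reset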