annotate src/samer/models/GeneralisedExponential.java @ 3:15b93db27c04

Get StreamSource to compile, update args for demo
author samer
date Fri, 05 Apr 2019 17:00:18 +0100
parents bf79fb79ee13
children
rev   line source
samer@0 1 package samer.models;
samer@0 2
samer@0 3 import samer.maths.*;
samer@0 4 import samer.maths.opt.*;
samer@0 5 import samer.tools.*;
samer@0 6 import samer.core.*;
samer@0 7 import samer.core.types.*;
samer@0 8
samer@0 9
/**
 * Generalised exponential prior model: per-component energies
 * e_i = |x_i|^alpha_i, total energy E = sum_i e_i + L0, where L0 collects
 * the alpha-dependent normalisation terms log(2/alpha_i) + logGamma(1/alpha_i)
 * (see precompute()).  Implements Model over an externally supplied input Vec.
 */
public class GeneralisedExponential extends NullTask implements Model {
	Vec input;                 // current input vector
	VVector alpha, e, grad;    // exponents, per-component energies, gradient dE/dx
	int N;                     // dimensionality
	VDouble E;                 // observable total energy
	double[] x, g, e0, _e, a;  // raw arrays: input, gradient, normalisation terms, energies, alphas
	double L0;                 // x-independent part of the energy (sum of e0)

	/** Size the model to the given vector and attach it as input. */
	public GeneralisedExponential(Vec x) { this(x.size()); setInput(x); }

	/**
	 * Build an n-dimensional model; every exponent starts at 1.0.
	 * NOTE(review): precompute() is never invoked, so L0 is left at 0 here
	 * even though for alpha=1 the normalisation would be N*log 2; L0 is only
	 * refreshed by Trainer.flush().  Confirm this initial value is intentional.
	 */
	public GeneralisedExponential(int n) {
		N=n;
		E=new VDouble("E");
		e=new VVector("e",N);
		grad=new VVector("phi",N);
		alpha=new VVector("alpha",N);
		alpha.addSaver();

		// cache the backing arrays of the viewable vectors
		g=grad.array(); // new double[N];
		e0=new double[N];
		a=alpha.array();
		_e=e.array();

		Mathx.setAll(a,1.0); L0=0;
	}
samer@0 34
samer@0 35 public String toString() { return "GeneralisedExponential("+input+")"; }
samer@0 36 public void setInput(Vec in) { input=in; x=input.array(); }
samer@0 37 public int getSize() { return N; }
samer@0 38 public void dispose() {
samer@0 39 alpha.dispose();
samer@0 40 grad.dispose();
samer@0 41 E.dispose();
samer@0 42 e.dispose();
samer@0 43 }
samer@0 44
samer@0 45 public VVector getEnergyVector() { return e; }
samer@0 46 public VDouble getEnergySignal() { return E; }
samer@0 47 public double getEnergy() { return E.value; }
samer@0 48 public double [] getGradient() { return g; }
samer@0 49 public VVector getAlphas() { return alpha; }
samer@0 50
samer@0 51 public void run() { compute(); }
samer@0 52 public void infer() {}
samer@0 53 public void compute() {
samer@0 54
samer@0 55 // compute log likelihood
samer@0 56
samer@0 57 for (int i=0; i<N; i++) _e[i] = Math.pow(Math.abs(x[i]),a[i]);
samer@0 58
samer@0 59 // compute gradient g_i = dL/dx_i
samer@0 60 for (int i=0; i<N; i++) {
samer@0 61 if (x[i]==0) g[i]=0;
samer@0 62 else g[i] = a[i]*(_e[i]/x[i]);
samer@0 63 }
samer@0 64
samer@0 65 e.changed();
samer@0 66 grad.changed();
samer@0 67 E.set(Mathx.sum(_e)+L0);
samer@0 68 }
samer@0 69
/**
 * Recompute the x-independent part of the energy (a function of alpha
 * only): e0_i = log(2/alpha_i) + logGamma(1/alpha_i), summed into L0.
 * NOTE(review): this private method appears to be unused in this file —
 * Trainer.flush() performs the same update inline; consider calling this
 * from the constructor so L0 matches the initial alphas.
 */
private void precompute() {
	// this computes the x-independent part of log p(x), ie fn of alpha
	for (int i=0; i<N; i++) e0[i]= Math.log((2/a[i]))+logGamma(1/a[i]);
	L0=Mathx.sum(e0);
}
samer@0 75
samer@0 76 public Functionx functionx() {
samer@0 77 return new Functionx() {
samer@0 78 double [] __e=new double[N];
samer@0 79 public void dispose() {}
samer@0 80 public void evaluate(Datum P) { P.f=evaluate(P.x,P.g); }
samer@0 81 public double evaluate(double [] x, double [] g) {
samer@0 82 for (int i=0; i<N; i++) {
samer@0 83 if (x[i]==0) { g[i]=0; __e[i]=0; }
samer@0 84 else {
samer@0 85 __e[i] = Math.pow(Math.abs(x[i]),a[i]);
samer@0 86 g[i] = a[i]*(__e[i]/x[i]);
samer@0 87 }
samer@0 88 }
samer@0 89 return Mathx.sum(__e)+L0;
samer@0 90 }
samer@0 91 };
samer@0 92 }
samer@0 93
samer@0 94 public Trainer getTrainer() { return new Trainer(); }
samer@0 95
samer@0 96 public class Trainer extends AnonymousTask implements Model.Trainer {
samer@0 97 VDouble rate; // learning rate
samer@0 98 double[] A; // statistics
samer@0 99 double count;
samer@0 100
samer@0 101 // estimation:
samer@0 102 // 1/beta = alpha*avg(abs(x^alpha));
samer@0 103
samer@0 104 public Trainer() {
samer@0 105 rate=new VDouble("rate",0.001);
samer@0 106 A=new double[N];
samer@0 107 }
samer@0 108
samer@0 109 public String toString() { return "Trainer:"+GeneralisedExponential.this; }
samer@0 110 public VDouble getRate() { return rate; }
samer@0 111
samer@0 112 public void reset() { Mathx.zero(A); count=0; }
samer@0 113 public void accumulate() { accumulate(1); }
samer@0 114 public void accumulate(double w) {
samer@0 115 for (int i=0; i<N; i++) {
samer@0 116 if (x[i]!=0) A[i] -= w*_e[i]*Math.log(Math.abs(x[i]));
samer@0 117 }
samer@0 118 count+=w;
samer@0 119 }
samer@0 120
samer@0 121 public void flush() {
samer@0 122 if (count==0) return;
samer@0 123 double eta=rate.value;
samer@0 124
samer@0 125 for (int i=0; i<N; i++) {
samer@0 126 double ra=1/a[i];
samer@0 127 a[i] += eta*(A[i]/count + ra*(1+ra*digamma(ra)));
samer@0 128 // make sure a[i] not negative:
samer@0 129 if (a[i]<=0) a[i]=0.01; // a small value
samer@0 130
samer@0 131 // precompute for next round
samer@0 132 e0[i]= Math.log(2*ra)+logGamma(ra);
samer@0 133 }
samer@0 134 L0=Mathx.sum(e0);
samer@0 135 alpha.changed();
samer@0 136 reset();
samer@0 137 }
samer@0 138
samer@0 139 public void oneshot() { reset(); accumulate(1); flush(); }
samer@0 140 public void dispose() { rate.dispose(); }
samer@0 141 public void starting() { reset(); }
samer@0 142 public void run() { accumulate(1); flush(); }
samer@0 143
samer@0 144 }
samer@0 145
samer@0 146 private static double S=1e-5, C=8.5,
samer@0 147 S3=8.33333333333333333333E-2,
samer@0 148 S4=8.33333333333333333333E-3,
samer@0 149 S5=3.96825396825396825397E-3,
samer@0 150 D1=-0.5772156649;
samer@0 151
samer@0 152 public static double digamma(double t) {
samer@0 153 double z=0;
samer@0 154 if (t<S) return D1-1/t;
samer@0 155 for (z=0; t<C; t++) z-=1/t;
samer@0 156 double r=1/t;
samer@0 157 z+=Math.log(t) - 0.5*r;
samer@0 158 r*=r;
samer@0 159 z-=r*(S3 - r*(S4 - r*S5));
samer@0 160 return z;
samer@0 161 }
samer@0 162
/**
 * Gamma function, delegated to the project's Gamma implementation.
 * NOTE(review): appears unused within this file; kept alongside logGamma.
 */
private static double gamma(double t) {
	return samer.functions.Gamma.gamma(t);
	// return edu.uah.math.distributions.Functions.gamma(t);
}

/** log(Gamma(t)), used for the normalisation terms e0. */
private static double logGamma(double t) {
	return samer.functions.Gamma.logOfGamma(t);
	// return edu.uah.math.distributions.Functions.logGamma(t);
}
samer@0 171 }