samer@0
|
1 package samer.models;
|
samer@0
|
2
|
samer@0
|
3 import samer.maths.*;
|
samer@0
|
4 import samer.maths.opt.*;
|
samer@0
|
5 import samer.tools.*;
|
samer@0
|
6 import samer.core.*;
|
samer@0
|
7 import samer.core.types.*;
|
samer@0
|
8
|
samer@0
|
9
|
samer@0
|
10 public class GeneralisedExponential extends NullTask implements Model {
|
samer@0
|
11 Vec input;
|
samer@0
|
12 VVector alpha, e, grad;
|
samer@0
|
13 int N;
|
samer@0
|
14 VDouble E;
|
samer@0
|
15 double[] x, g, e0, _e, a;
|
samer@0
|
16 double L0;
|
samer@0
|
17
|
samer@0
|
18 public GeneralisedExponential(Vec x) { this(x.size()); setInput(x); }
|
samer@0
|
19 public GeneralisedExponential(int n) {
|
samer@0
|
20 N=n;
|
samer@0
|
21 E=new VDouble("E");
|
samer@0
|
22 e=new VVector("e",N);
|
samer@0
|
23 grad=new VVector("phi",N);
|
samer@0
|
24 alpha=new VVector("alpha",N);
|
samer@0
|
25 alpha.addSaver();
|
samer@0
|
26
|
samer@0
|
27 g=grad.array(); // new double[N];
|
samer@0
|
28 e0=new double[N];
|
samer@0
|
29 a=alpha.array();
|
samer@0
|
30 _e=e.array();
|
samer@0
|
31
|
samer@0
|
32 Mathx.setAll(a,1.0); L0=0;
|
samer@0
|
33 }
|
samer@0
|
34
|
samer@0
|
35 public String toString() { return "GeneralisedExponential("+input+")"; }
|
samer@0
|
36 public void setInput(Vec in) { input=in; x=input.array(); }
|
samer@0
|
37 public int getSize() { return N; }
|
samer@0
|
38 public void dispose() {
|
samer@0
|
39 alpha.dispose();
|
samer@0
|
40 grad.dispose();
|
samer@0
|
41 E.dispose();
|
samer@0
|
42 e.dispose();
|
samer@0
|
43 }
|
samer@0
|
44
|
samer@0
|
45 public VVector getEnergyVector() { return e; }
|
samer@0
|
46 public VDouble getEnergySignal() { return E; }
|
samer@0
|
47 public double getEnergy() { return E.value; }
|
samer@0
|
48 public double [] getGradient() { return g; }
|
samer@0
|
49 public VVector getAlphas() { return alpha; }
|
samer@0
|
50
|
samer@0
|
51 public void run() { compute(); }
|
samer@0
|
52 public void infer() {}
|
samer@0
|
53 public void compute() {
|
samer@0
|
54
|
samer@0
|
55 // compute log likelihood
|
samer@0
|
56
|
samer@0
|
57 for (int i=0; i<N; i++) _e[i] = Math.pow(Math.abs(x[i]),a[i]);
|
samer@0
|
58
|
samer@0
|
59 // compute gradient g_i = dL/dx_i
|
samer@0
|
60 for (int i=0; i<N; i++) {
|
samer@0
|
61 if (x[i]==0) g[i]=0;
|
samer@0
|
62 else g[i] = a[i]*(_e[i]/x[i]);
|
samer@0
|
63 }
|
samer@0
|
64
|
samer@0
|
65 e.changed();
|
samer@0
|
66 grad.changed();
|
samer@0
|
67 E.set(Mathx.sum(_e)+L0);
|
samer@0
|
68 }
|
samer@0
|
69
|
samer@0
|
70 private void precompute() {
|
samer@0
|
71 // this computes the x-independent part of log p(x), ie fn of alpha
|
samer@0
|
72 for (int i=0; i<N; i++) e0[i]= Math.log((2/a[i]))+logGamma(1/a[i]);
|
samer@0
|
73 L0=Mathx.sum(e0);
|
samer@0
|
74 }
|
samer@0
|
75
|
samer@0
|
76 public Functionx functionx() {
|
samer@0
|
77 return new Functionx() {
|
samer@0
|
78 double [] __e=new double[N];
|
samer@0
|
79 public void dispose() {}
|
samer@0
|
80 public void evaluate(Datum P) { P.f=evaluate(P.x,P.g); }
|
samer@0
|
81 public double evaluate(double [] x, double [] g) {
|
samer@0
|
82 for (int i=0; i<N; i++) {
|
samer@0
|
83 if (x[i]==0) { g[i]=0; __e[i]=0; }
|
samer@0
|
84 else {
|
samer@0
|
85 __e[i] = Math.pow(Math.abs(x[i]),a[i]);
|
samer@0
|
86 g[i] = a[i]*(__e[i]/x[i]);
|
samer@0
|
87 }
|
samer@0
|
88 }
|
samer@0
|
89 return Mathx.sum(__e)+L0;
|
samer@0
|
90 }
|
samer@0
|
91 };
|
samer@0
|
92 }
|
samer@0
|
93
|
samer@0
|
94 public Trainer getTrainer() { return new Trainer(); }
|
samer@0
|
95
|
samer@0
|
96 public class Trainer extends AnonymousTask implements Model.Trainer {
|
samer@0
|
97 VDouble rate; // learning rate
|
samer@0
|
98 double[] A; // statistics
|
samer@0
|
99 double count;
|
samer@0
|
100
|
samer@0
|
101 // estimation:
|
samer@0
|
102 // 1/beta = alpha*avg(abs(x^alpha));
|
samer@0
|
103
|
samer@0
|
104 public Trainer() {
|
samer@0
|
105 rate=new VDouble("rate",0.001);
|
samer@0
|
106 A=new double[N];
|
samer@0
|
107 }
|
samer@0
|
108
|
samer@0
|
109 public String toString() { return "Trainer:"+GeneralisedExponential.this; }
|
samer@0
|
110 public VDouble getRate() { return rate; }
|
samer@0
|
111
|
samer@0
|
112 public void reset() { Mathx.zero(A); count=0; }
|
samer@0
|
113 public void accumulate() { accumulate(1); }
|
samer@0
|
114 public void accumulate(double w) {
|
samer@0
|
115 for (int i=0; i<N; i++) {
|
samer@0
|
116 if (x[i]!=0) A[i] -= w*_e[i]*Math.log(Math.abs(x[i]));
|
samer@0
|
117 }
|
samer@0
|
118 count+=w;
|
samer@0
|
119 }
|
samer@0
|
120
|
samer@0
|
121 public void flush() {
|
samer@0
|
122 if (count==0) return;
|
samer@0
|
123 double eta=rate.value;
|
samer@0
|
124
|
samer@0
|
125 for (int i=0; i<N; i++) {
|
samer@0
|
126 double ra=1/a[i];
|
samer@0
|
127 a[i] += eta*(A[i]/count + ra*(1+ra*digamma(ra)));
|
samer@0
|
128 // make sure a[i] not negative:
|
samer@0
|
129 if (a[i]<=0) a[i]=0.01; // a small value
|
samer@0
|
130
|
samer@0
|
131 // precompute for next round
|
samer@0
|
132 e0[i]= Math.log(2*ra)+logGamma(ra);
|
samer@0
|
133 }
|
samer@0
|
134 L0=Mathx.sum(e0);
|
samer@0
|
135 alpha.changed();
|
samer@0
|
136 reset();
|
samer@0
|
137 }
|
samer@0
|
138
|
samer@0
|
139 public void oneshot() { reset(); accumulate(1); flush(); }
|
samer@0
|
140 public void dispose() { rate.dispose(); }
|
samer@0
|
141 public void starting() { reset(); }
|
samer@0
|
142 public void run() { accumulate(1); flush(); }
|
samer@0
|
143
|
samer@0
|
144 }
|
samer@0
|
145
|
samer@0
|
146 private static double S=1e-5, C=8.5,
|
samer@0
|
147 S3=8.33333333333333333333E-2,
|
samer@0
|
148 S4=8.33333333333333333333E-3,
|
samer@0
|
149 S5=3.96825396825396825397E-3,
|
samer@0
|
150 D1=-0.5772156649;
|
samer@0
|
151
|
samer@0
|
152 public static double digamma(double t) {
|
samer@0
|
153 double z=0;
|
samer@0
|
154 if (t<S) return D1-1/t;
|
samer@0
|
155 for (z=0; t<C; t++) z-=1/t;
|
samer@0
|
156 double r=1/t;
|
samer@0
|
157 z+=Math.log(t) - 0.5*r;
|
samer@0
|
158 r*=r;
|
samer@0
|
159 z-=r*(S3 - r*(S4 - r*S5));
|
samer@0
|
160 return z;
|
samer@0
|
161 }
|
samer@0
|
162
|
samer@0
|
163 private static double gamma(double t) {
|
samer@0
|
164 return samer.functions.Gamma.gamma(t);
|
samer@0
|
165 // return edu.uah.math.distributions.Functions.gamma(t);
|
samer@0
|
166 }
|
samer@0
|
167 private static double logGamma(double t) {
|
samer@0
|
168 return samer.functions.Gamma.logOfGamma(t);
|
samer@0
|
169 // return edu.uah.math.distributions.Functions.logGamma(t);
|
samer@0
|
170 }
|
samer@0
|
171 }
|