Mercurial > hg > jslab
view src/samer/models/MOGModel.java @ 5:b67a33c44de7
Remove some crap, etc
author | samer |
---|---|
date | Fri, 05 Apr 2019 21:34:25 +0100 |
parents | bf79fb79ee13 |
children |
line wrap: on
line source
package samer.models; import samer.core.*; import samer.core.types.*; import samer.maths.*; import samer.tools.*; public class MOGModel extends NamedTask { Generator input; int N, state; // num states VDouble Z; VVector vt, vpost; // deviations, posterior VVector vw, vmu, vsig; // weights, means, and stddevs double[] t, post; // deviations, posterior double[] w, mu, sig; // weights, means, and stddevs double[] dw, dmu, dsig; // deltas VDouble nuw, numu, nusig; // learning rates boolean lrnW, lrnMu, lrnSig; final static double Q=1/Math.sqrt(2*Math.PI);; public MOGModel(Generator in, int n) { super("mog"); input=in; Shell.push(node); N=n; Z = new VDouble("likelihood"); t = new double[N]; post = new double[N]; w = new double[N]; mu = new double[N]; sig = new double[N]; dw = new double[N]; dmu = new double[N]; dsig = new double[N]; vpost=new VVector("p",post); vw=new VVector("weights",w); vmu=new VVector("means",mu); vsig=new VVector("sigmas",sig); lrnW=Shell.getBoolean("weights.learn",true); lrnMu=Shell.getBoolean("means.learn",true); lrnSig=Shell.getBoolean("sigmas.learn",true); nuw=new VDouble("weights.learn.rate",.001); numu=new VDouble("means.learn.rate",.001); nusig=new VDouble("sigmas.learn.rate",.001); // nu=new VDouble("learn.rate",0.001); Shell.pop(); // initialise parameters Generator rnd=new samer.maths.random.NormalisedGaussian(); // weights... Mathx.zero(w); Mathx.add(w,1.0/N); vw.changed(); // means... Mathx.set(mu,rnd); vmu.changed(); // sigmas Mathx.zero(sig); Mathx.add(sig,1); vsig.changed(); } public void run() { data(input.next()); } public void starting() { Mathx.zero(dw); Mathx.zero(dmu); Mathx.zero(dsig); // make sure weights are normalised properly Mathx.mul(w,1/Mathx.sum(w)); } public Task flushTask() { return new AnonymousTask() { public void run() { flush(); } }; } public Function getPDF() { return new Function() { public double apply(double x) { double Z=0; for (int i=0; i<N; i++) { t[i]=(x-mu[i])/sig[i]; Z+=(w[i]/sig[i])*Math.exp(-0.5*t[i]*t[i]); } return Z*Q; } public String format(String x) { return "mogpdf("+x+")"; } }; } public void data(double x) { // find posterior and MAP state state=0; double Z=0; for (int i=0; i<N; i++) { t[i]=(x-mu[i])/sig[i]; post[i] = (w[i]/sig[i])*Math.exp(-0.5*t[i]*t[i]); Z+=post[i]; if (post[i]>post[state]) state=i; } Mathx.mul(post,1/Z); Z *= Q; if (lrnW||lrnMu||lrnSig) { // buffer changes to parameters for (int i=0; i<N; i++) { double p=post[i], tt=t[i]; if (lrnW) dw[i] += p; if (lrnMu) dmu[i] += p*tt; if (lrnSig) dsig[i]+= p*(tt*tt-1); } } this.Z.value=Z; this.Z.changed(); vpost.changed(); } public void flush() { if (lrnMu) { Mathx.mul(dmu,sig); flush(mu,dmu,numu.value); vmu.changed(); } if (lrnSig) { Mathx.mul(dsig,sig); flush(sig,dsig,nusig.value); vsig.changed(); } if (lrnW) { /* // normalise Mathx.mul(w,1/Mathx.sum(w)); Mathx.div(dw,w); double lambda=Mathx.sum(dw)/N; Mathx.sub(dw,lambda); flush(w,dw,nuw.value); */ double lambda=Mathx.dot(w,dw)/Mathx.dot(w,w); double nu=nuw.value; for (int i=0; i<N; i++) { dw[i] -= lambda*w[i]; // project away from w w[i] *= Math.exp(nu*dw[i]); // update w } Mathx.zero(dw); // normalise Mathx.mul(w,1/Mathx.sum(w)); vw.changed(); } } private void flush(double [] theta, double [] dtheta, double nu) { Mathx.mul(dtheta,nu); // *this.nu.value); Mathx.add(theta,dtheta); Mathx.zero(dtheta); } }