Mercurial > hg > jslab
view src/samer/models/MOGVector.java @ 0:bf79fb79ee13
Initial Mercurial check in.
author | samer |
---|---|
date | Tue, 17 Jan 2012 17:50:20 +0000 |
parents | |
children |
line wrap: on
line source
package samer.models; import samer.core.*; import samer.core.types.*; import samer.maths.*; import samer.tools.*; import java.util.*; public class MOGVector extends NamedTask { Vec input; int N, M; // num states, number of inputs VVector s; // states VVector l; // likelihoods VDouble L; // total likelihood Matrix t; // deviations: t=(x-mu)/sig Matrix post; // posterior over states Matrix w, mu, sig; // weights, means, stddevs Matrix dw, dmu, dsig; // accumulated stats for learning VDouble nuw, numu, nusig; // learning rates // ------------ private bits ------------------------ private final static double Q=0.5*Math.log(2*Math.PI); private double[] tmp; private double [] _s; private double [] _l; private double [][] _t; private double [][] _w; private double [][] _mu; private double [][] _sig; private double [][] _post; // --------------------------------------------------- public MOGVector(Vec in, int n) { super("mog"); input=in; Shell.push(node); N=n; M=in.size(); s = new VVector("state",M); l = new VVector("-log p(s)",M); L = new VDouble("likelihood"); t = new Matrix("t",N,M); post = new Matrix("p(s|x)",N,M); mu = new Matrix("means",N,M); sig = new Matrix("sigmas",N,M); w = new Matrix("weight",N,M); dmu = new Matrix("dmu",N,M); dsig = new Matrix("dsig",N,M); dw = new Matrix("dw",N,M); tmp=new double[M]; Shell.pop(); // initialise parameters // weights... w.set(new Constant(1.0/N)); w.changed(); // means... for (int i=0; i<N; i++) { mu.setMatrix(i,i,0,M-1,new Jama.Matrix(1,M,(double)i)); } mu.changed(); // sigmas sig.set(new Constant(1)); sig.changed(); // ----- initialise private bits ---------- _s=s.array(); _l=l.array(); _t=t.getArray(); _w=w.getArray(); _mu=mu.getArray(); _sig=sig.getArray(); _post=post.getArray(); } // normalise column sum of w, return array of sums private void normalise(double [][] w, double[] sum) { // make sure weights are normalised properly Mathx.zero(sum); for (int i=0; i<N; i++) Mathx.add(sum,w[i]); for (int i=0; i<N; i++) Mathx.div(w[i],sum); } public VFunction getPDF() { PDF fn=new PDF(); VFunction vfn=new VFunction("pdf",fn); fn.setupObservers(vfn); return vfn; } public void run() { // matlab equivalent code: // t = (x - mu)./sig // post = (w./sig).*exp(-0.5*t.^2) // [dummy, s] = max(post,2) // Z = sum(post,2) // post = post./(Z*ones(1,N)) // Z=Z*Q; { // Vec.Iterator it=input.iterator(); double [] x=input.array(); for (int i=0; i<N; i++) { for (int j=0; j<M; j++) { _t[i][j] = (x[j]-_mu[i][j])/_sig[i][j]; _post[i][j] = (_w[i][j]/_sig[i][j])*Math.exp(-0.5*_t[i][j]*_t[i][j]); } } // this computes partition function // and normalises posterior in one go normalise(_post,_l); // get MAP state for (int j=0; j<M; j++) { int state=0; double pmax=_post[0][j]; for (int i=1; i<N; i++) { if (_post[i][j]>pmax) { state=i; pmax=_post[i][j]; } } _s[j]=state; } for (int j=0; j<M; j++) _l[j]=-Math.log(_l[j]); L.value = Mathx.sum(_l); } L.changed(); t.changed(); s.changed(); l.changed(); post.changed(); } public Task learnTask() { return new AnonymousTask() { // buffer changes to parameters double [][] _dw=dw.getArray(); double [][] _dmu=dmu.getArray(); double [][] _dsig=dsig.getArray(); public void starting() { dw.zero(); dmu.zero(); dsig.zero(); normalise(w.getArray(),tmp); } public void run() { for (int i=0; i<N; i++) { for (int j=0; j<M; j++) { double pp=_post[i][j], tt=_t[i][j]; _dw[i][j] += pp; _dmu[i][j] += pp*tt; _dsig[i][j]+= pp*(tt*tt-1); } } } }; } public Task flushTask() { Shell.push(node); try { return new AnonymousTask() { VDouble nuw=new VDouble("weights.learn.rate",.001); VDouble numu=new VDouble("means.learn.rate",.001); VDouble nusig=new VDouble("sigmas.learn.rate",.001); double [][] _w=w.getArray(); double [][] _dw=dw.getArray(); public void run() { // mu += numu*dmu.*sig // sig += nusig*dsig.*sig; // lambda = sum(w.*dw,2)./sum(w.*w,2) // dw -= (lambda*ones(1,N)) .* w // w *= exp(nu*dw) dmu.arrayTimesEquals(sig); dmu.timesEquals(nusig.value); mu.plusEquals(dmu); mu.changed(); dmu.zero(); dsig.arrayTimesEquals(sig); dsig.timesEquals(nusig.value); sig.plusEquals(dsig); sig.changed(); dsig.zero(); { // the effect of this is to project // dw away from w. the resulting vector // is then added to the log of w double nu=nuw.value; for (int j=0; j<M; j++) { double w2=0, lambda=0; for (int i=0; i<N; i++) { lambda += _w[i][j]*_dw[i][j]; w2 += _w[i][j]*_w[i][j]; } tmp[j]=lambda/w2; } for (int i=0; i<N; i++) { for (int j=0; j<M; j++) { _w[i][j] *= Math.exp( nu*(_dw[i][j] - tmp[j]*_w[i][j]) ); // update w } } } normalise(w.getArray(),tmp); w.changed(); dw.zero(); } }; } finally { Shell.pop(); } } class PDF extends Function implements Observer { VInteger index; Viewable vbl=null; public PDF() { index=new VInteger("index",M/2); index.setRange(0,M-1); index.addObserver(this); } public void dispose() { index.dispose(); } public double apply(double x) { double Z=0; for (int i=0; i<N; i++) { int j=index.value; double t=(x-_mu[i][j])/_sig[i][j]; Z+=(_w[i][j]/_sig[i][j])*Math.exp(-0.5*t*t); } return Z; } public String format(String x) { return "mogpdf("+x+")"; } public void setupObservers(Viewable v) { w.addObserver(this); mu.addObserver(this); sig.addObserver(this); vbl=v; } public void update(Observable o, Object arg) { if (arg!=Viewable.DISPOSING) vbl.changed(); } } }