Mercurial > hg > jslab
diff src/samer/models/MOGVector.java @ 0:bf79fb79ee13
Initial Mercurial check in.
author | samer |
---|---|
date | Tue, 17 Jan 2012 17:50:20 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/samer/models/MOGVector.java Tue Jan 17 17:50:20 2012 +0000 @@ -0,0 +1,284 @@ +package samer.models; + +import samer.core.*; +import samer.core.types.*; +import samer.maths.*; +import samer.tools.*; +import java.util.*; + +public class MOGVector extends NamedTask +{ + Vec input; + + int N, M; // num states, number of inputs + VVector s; // states + VVector l; // likelihoods + VDouble L; // total likelihood + + Matrix t; // deviations: t=(x-mu)/sig + Matrix post; // posterior over states + Matrix w, mu, sig; // weights, means, stddevs + Matrix dw, dmu, dsig; // accumulated stats for learning + + VDouble nuw, numu, nusig; // learning rates + + + // ------------ private bits ------------------------ + + private final static double Q=0.5*Math.log(2*Math.PI); + private double[] tmp; + private double [] _s; + private double [] _l; + private double [][] _t; + private double [][] _w; + private double [][] _mu; + private double [][] _sig; + private double [][] _post; + // --------------------------------------------------- + + public MOGVector(Vec in, int n) + { + super("mog"); + input=in; + + Shell.push(node); + N=n; M=in.size(); + + s = new VVector("state",M); + l = new VVector("-log p(s)",M); + L = new VDouble("likelihood"); + + t = new Matrix("t",N,M); + post = new Matrix("p(s|x)",N,M); + mu = new Matrix("means",N,M); + sig = new Matrix("sigmas",N,M); + w = new Matrix("weight",N,M); + dmu = new Matrix("dmu",N,M); + dsig = new Matrix("dsig",N,M); + dw = new Matrix("dw",N,M); + + tmp=new double[M]; + + Shell.pop(); + + // initialise parameters + + // weights... + w.set(new Constant(1.0/N)); + w.changed(); + + // means... + for (int i=0; i<N; i++) { + mu.setMatrix(i,i,0,M-1,new Jama.Matrix(1,M,(double)i)); + } + mu.changed(); + + // sigmas + sig.set(new Constant(1)); + sig.changed(); + + + // ----- initialise private bits ---------- + _s=s.array(); + _l=l.array(); + _t=t.getArray(); + _w=w.getArray(); + _mu=mu.getArray(); + _sig=sig.getArray(); + _post=post.getArray(); + } + + // normalise column sum of w, return array of sums + private void normalise(double [][] w, double[] sum) { + // make sure weights are normalised properly + Mathx.zero(sum); + for (int i=0; i<N; i++) Mathx.add(sum,w[i]); + for (int i=0; i<N; i++) Mathx.div(w[i],sum); + } + + public VFunction getPDF() { + PDF fn=new PDF(); + VFunction vfn=new VFunction("pdf",fn); + fn.setupObservers(vfn); + return vfn; + } + + public void run() + { + // matlab equivalent code: + // t = (x - mu)./sig + // post = (w./sig).*exp(-0.5*t.^2) + // [dummy, s] = max(post,2) + // Z = sum(post,2) + // post = post./(Z*ones(1,N)) + // Z=Z*Q; + + { + // Vec.Iterator it=input.iterator(); + + double [] x=input.array(); + + for (int i=0; i<N; i++) { + for (int j=0; j<M; j++) { + _t[i][j] = (x[j]-_mu[i][j])/_sig[i][j]; + _post[i][j] = (_w[i][j]/_sig[i][j])*Math.exp(-0.5*_t[i][j]*_t[i][j]); + } + } + + // this computes partition function + // and normalises posterior in one go + normalise(_post,_l); + + // get MAP state + for (int j=0; j<M; j++) { + int state=0; + double pmax=_post[0][j]; + + for (int i=1; i<N; i++) { + if (_post[i][j]>pmax) { state=i; pmax=_post[i][j]; } + } + _s[j]=state; + } + + for (int j=0; j<M; j++) _l[j]=-Math.log(_l[j]); + + L.value = Mathx.sum(_l); + } + + L.changed(); + t.changed(); + s.changed(); + l.changed(); + post.changed(); + } + + public Task learnTask() { + return new AnonymousTask() { + // buffer changes to parameters + double [][] _dw=dw.getArray(); + double [][] _dmu=dmu.getArray(); + double [][] _dsig=dsig.getArray(); + + public void starting() { + dw.zero(); + dmu.zero(); + dsig.zero(); + normalise(w.getArray(),tmp); + } + + public void run() { + for (int i=0; i<N; i++) { + for (int j=0; j<M; j++) { + double pp=_post[i][j], tt=_t[i][j]; + _dw[i][j] += pp; + _dmu[i][j] += pp*tt; + _dsig[i][j]+= pp*(tt*tt-1); + } + } + } + }; + } + + public Task flushTask() { + Shell.push(node); + + try { + return new AnonymousTask() { + + VDouble nuw=new VDouble("weights.learn.rate",.001); + VDouble numu=new VDouble("means.learn.rate",.001); + VDouble nusig=new VDouble("sigmas.learn.rate",.001); + + double [][] _w=w.getArray(); + double [][] _dw=dw.getArray(); + + public void run() + { + // mu += numu*dmu.*sig + // sig += nusig*dsig.*sig; + // lambda = sum(w.*dw,2)./sum(w.*w,2) + // dw -= (lambda*ones(1,N)) .* w + // w *= exp(nu*dw) + + dmu.arrayTimesEquals(sig); + dmu.timesEquals(nusig.value); + mu.plusEquals(dmu); + mu.changed(); + dmu.zero(); + + dsig.arrayTimesEquals(sig); + dsig.timesEquals(nusig.value); + sig.plusEquals(dsig); + sig.changed(); + dsig.zero(); + + { + // the effect of this is to project + // dw away from w. the resulting vector + // is then added to the log of w + + double nu=nuw.value; + + for (int j=0; j<M; j++) { + double w2=0, lambda=0; + for (int i=0; i<N; i++) { + lambda += _w[i][j]*_dw[i][j]; + w2 += _w[i][j]*_w[i][j]; + } + tmp[j]=lambda/w2; + } + + for (int i=0; i<N; i++) { + for (int j=0; j<M; j++) { + _w[i][j] *= Math.exp( + nu*(_dw[i][j] - tmp[j]*_w[i][j]) + ); // update w + } + } + } + + normalise(w.getArray(),tmp); + w.changed(); + dw.zero(); + } + }; + } finally { Shell.pop(); } + } + + class PDF extends Function implements Observer { + VInteger index; + Viewable vbl=null; + + public PDF() { + index=new VInteger("index",M/2); + index.setRange(0,M-1); + index.addObserver(this); + } + + public void dispose() { index.dispose(); } + + public double apply(double x) { + double Z=0; + for (int i=0; i<N; i++) { + int j=index.value; + double t=(x-_mu[i][j])/_sig[i][j]; + Z+=(_w[i][j]/_sig[i][j])*Math.exp(-0.5*t*t); + } + return Z; + } + public String format(String x) { + return "mogpdf("+x+")"; + } + + public void setupObservers(Viewable v) { + w.addObserver(this); + mu.addObserver(this); + sig.addObserver(this); + vbl=v; + } + + public void update(Observable o, Object arg) { + if (arg!=Viewable.DISPOSING) vbl.changed(); + } + } +}