samer@0: package samer.models; samer@0: samer@0: import samer.core.*; samer@0: import samer.core.types.*; samer@0: import samer.maths.*; samer@0: import samer.tools.*; samer@0: import java.util.*; samer@0: samer@0: public class MOGVector extends NamedTask samer@0: { samer@0: Vec input; samer@0: samer@0: int N, M; // num states, number of inputs samer@0: VVector s; // states samer@0: VVector l; // likelihoods samer@0: VDouble L; // total likelihood samer@0: samer@0: Matrix t; // deviations: t=(x-mu)/sig samer@0: Matrix post; // posterior over states samer@0: Matrix w, mu, sig; // weights, means, stddevs samer@0: Matrix dw, dmu, dsig; // accumulated stats for learning samer@0: samer@0: VDouble nuw, numu, nusig; // learning rates samer@0: samer@0: samer@0: // ------------ private bits ------------------------ samer@0: samer@0: private final static double Q=0.5*Math.log(2*Math.PI); samer@0: private double[] tmp; samer@0: private double [] _s; samer@0: private double [] _l; samer@0: private double [][] _t; samer@0: private double [][] _w; samer@0: private double [][] _mu; samer@0: private double [][] _sig; samer@0: private double [][] _post; samer@0: // --------------------------------------------------- samer@0: samer@0: public MOGVector(Vec in, int n) samer@0: { samer@0: super("mog"); samer@0: input=in; samer@0: samer@0: Shell.push(node); samer@0: N=n; M=in.size(); samer@0: samer@0: s = new VVector("state",M); samer@0: l = new VVector("-log p(s)",M); samer@0: L = new VDouble("likelihood"); samer@0: samer@0: t = new Matrix("t",N,M); samer@0: post = new Matrix("p(s|x)",N,M); samer@0: mu = new Matrix("means",N,M); samer@0: sig = new Matrix("sigmas",N,M); samer@0: w = new Matrix("weight",N,M); samer@0: dmu = new Matrix("dmu",N,M); samer@0: dsig = new Matrix("dsig",N,M); samer@0: dw = new Matrix("dw",N,M); samer@0: samer@0: tmp=new double[M]; samer@0: samer@0: Shell.pop(); samer@0: samer@0: // initialise parameters samer@0: samer@0: // weights... samer@0: w.set(new Constant(1.0/N)); samer@0: w.changed(); samer@0: samer@0: // means... samer@0: for (int i=0; ipmax) { state=i; pmax=_post[i][j]; } samer@0: } samer@0: _s[j]=state; samer@0: } samer@0: samer@0: for (int j=0; j