diff src/samer/models/Mixture.java @ 0:bf79fb79ee13
Initial Mercurial check-in.
| author | samer |
| --- | --- |
| date | Tue, 17 Jan 2012 17:50:20 +0000 |
| parents | |
| children | |
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/samer/models/Mixture.java	Tue Jan 17 17:50:20 2012 +0000
@@ -0,0 +1,153 @@
+/*
+ * Copyright (c) 2002, Samer Abdallah, King's College London.
+ * All rights reserved.
+ *
+ * This software is provided AS IS and WITHOUT ANY WARRANTY;
+ * without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.
+ */
+
+package samer.models;
+
+import samer.core.*;
+import samer.core.types.*;
+import samer.maths.*;
+import samer.maths.opt.*;
+import samer.tools.*;
+
+/**
+ * A mixture of models: each observation is explained by one of m
+ * component models. infer() computes the posterior over components
+ * from their energies (negative log-likelihoods) and the prior weights.
+ */
+public class Mixture extends NamedTask implements Model
+{
+    private Model M[];      // component models
+    private int n, m;       // size of input vector, number of components
+    private Vec x;          // input
+    private VVector w;      // prior weights over components
+    private VVector s;      // posterior over components
+    private int k;          // index of MAP component
+    private VDouble Z;      // partition function, i.e. p(x), up to the exp(Emin) shift
+    private double Emin;    // energy shift used to keep the exponentials stable
+    private double[] _x, _s, _w, _g;
+
+    public Mixture(Vec input, int m) { this(input.size(), m); setInput(input); }
+
+    public Mixture(int N, int L)
+    {
+        super("mixture");
+        Shell.push(node);
+
+        n = N;
+        m = L;
+
+        x = null;
+        w = new VVector("prior", m);
+        s = new VVector("posterior", m);
+        Z = new VDouble("Z");
+        M = new Model[m];
+        Shell.pop();
+
+        _s = s.array();
+        _w = w.array();
+        _g = new double[n];
+        Mathx.set(_w, new Constant(1.0/L));  // start with a uniform prior
+    }
+
+    public VVector prior()     { return w; }
+    public VVector posterior() { return s; }
+    public void setModel(int i, Model m) { M[i] = m; }
+    public void setInput(Vec in) { x = in; _x = x.array(); }
+    public int getSize() { return n; }
+
+    public void dispose()
+    {
+        s.dispose();
+        w.dispose();
+        Z.dispose();
+        for (int i=0; i<m; i++) M[i].dispose();
+        super.dispose();
+    }
+
+    public void infer()
+    {
+        // components are assumed to have computed their energies already;
+        // uncomment to drive them from here:
+        // for (int i=0; i<m; i++) { M[i].infer(); M[i].compute(); }
+
+        // compute relative posterior, shifting by the minimum energy
+        // so the exponentials cannot overflow
+        for (int i=0; i<m; i++) _s[i] = M[i].getEnergy();  // collect energies
+        Emin = Mathx.min(_s);
+        for (int i=0; i<m; i++) _s[i] = _w[i]*Math.exp(Emin-_s[i]);
+
+        // compute partition function, normalise posterior
+        Z.set(Mathx.sum(_s));      // shifted partition function
+        k = Mathx.argmax(_s);      // MAP component
+        Mathx.mul(_s, 1/Z.value);  // normalise posterior
+        s.changed();
+    }
+
+    public void compute()
+    {
+        // compute gradients weighted by posterior
+        Mathx.zero(_g);
+        for (int i=0; i<m; i++) {
+            double[] phi = M[i].getGradient();
+            for (int j=0; j<n; j++) _g[j] += _s[i]*phi[j];
+        }
+    }
+
+    // undo the Emin shift so the result really is -log p(x)
+    public double getEnergy() { return Emin - Math.log(Z.value); }
+    public double[] getGradient() { return _g; }
+
+    public Functionx functionx() { return null; }
+
+    public void run() { infer(); }
+
+    public Trainer getTrainer() { return new Trainer(); }
+
+    public class Trainer implements Model.Trainer
+    {
+        Model.Trainer T[];
+        VDouble rate;
+        VVector dw;
+        double batch, _dw[];
+
+        public Trainer() {
+            T = new Model.Trainer[m];  // entries start out null
+            rate = new VDouble("rate", 0.001);
+            dw = new VVector("dw", m);
+            _dw = dw.array();
+        }
+
+        public void setTrainer(int i, Model.Trainer t) { T[i] = t; }
+        public void dispose() { rate.dispose(); dw.dispose(); }
+
+        public void accumulate() { accumulate(1.0); }
+        public void accumulate(double w) {
+            batch += w;
+            for (int i=0; i<m; i++) {
+                // each component learns in proportion to its responsibility
+                if (T[i] != null) T[i].accumulate(w*_s[i]);
+            }
+
+            // accumulate statistics for the prior weights
+            Mathx.add(_dw, _s);
+        }
+
+        public void oneshot() { accumulate(1.0); flush(); }
+
+        public void flush() {
+            for (int i=0; i<m; i++) if (T[i] != null) T[i].flush();
+
+            // multiplicative update of the prior weights: lambda projects out
+            // the component of the accumulated gradient along w, so the update
+            // is roughly scale-preserving before the explicit renormalisation
+            double lambda = Mathx.dot(_w,_dw)/Mathx.dot(_w,_w);
+            double nu = rate.value/batch;
+
+            dw.changed();
+            for (int i=0; i<m; i++) {
+                _w[i] *= Math.exp(nu*(_dw[i]-lambda*_w[i]));  // update w
+            }
+            Mathx.zero(_dw); batch = 0;
+
+            // normalise
+            Mathx.mul(_w, 1/Mathx.sum(_w));
+            w.changed();
+        }
+
+        public void reset() {
+            for (int i=0; i<m; i++) if (T[i] != null) T[i].reset();
+            Mathx.zero(_dw);
+            batch = 0;
+        }
+    }
+}
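
For orientation, a minimal usage sketch (not part of the check-in). It uses only the methods defined in the file above; the construction of the input `Vec` and of the two component models is assumed to happen elsewhere, and `input`, `m0` and `m1` are hypothetical placeholders.

```java
// Hypothetical wiring of a two-component mixture; `input` (a Vec),
// `m0` and `m1` (Models) are assumed to have been built already.
Mixture mix = new Mixture(input, 2);
mix.setModel(0, m0);
mix.setModel(1, m1);

Mixture.Trainer t = mix.getTrainer();
t.setTrainer(0, m0.getTrainer());  // optionally train the components too
t.setTrainer(1, m1.getTrainer());

// per observation: components compute their energies, then
mix.infer();     // posterior over components and partition function Z
mix.compute();   // posterior-weighted gradient
t.accumulate();  // accumulate statistics, weighted by responsibility

// at the end of a batch: apply the multiplicative prior update
t.flush();
```

Note that `flush()` multiplies each prior weight by `exp(nu*(_dw[i]-lambda*_w[i]))` and then renormalises, so the weights stay positive and sum to one.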