Mercurial > hg > jslab
view src/samer/models/Mixture.java @ 5:b67a33c44de7
Remove some crap, etc
author | samer |
---|---|
date | Fri, 05 Apr 2019 21:34:25 +0100 |
parents | bf79fb79ee13 |
children |
line wrap: on
line source
/* * Copyright (c) 2002, Samer Abdallah, King's College London. * All rights reserved. * * This software is provided AS iS and WITHOUT ANY WARRANTY; * without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. */ package samer.models; import samer.core.*; import samer.core.types.*; import samer.maths.*; import samer.maths.opt.*; import samer.tools.*; public class Mixture extends NamedTask implements Model { private Model M[]; // models private int n, m; // size of vector, num models private Vec x; // input private VVector w; // prior weights private VVector s; // posterior private int k; // MAP estimate private VDouble Z; // Parition function, ie p(x) private double[] _x,_s,_w,_g; public Mixture( Vec input, int m) { this(input.size(), m); setInput(input); } public Mixture( int N, int L) { super("mixture"); Shell.push(node); n = N; m = L; x = null; w = new VVector("prior",m); s = new VVector("posterior",m); Z = new VDouble("Z"); M = new Model[m]; Shell.pop(); _s=s.array(); _w=w.array(); _g=new double[n]; Mathx.set(_w,new Constant(1.0/L)); } public VVector prior() { return w; } public VVector posterior() { return s; } public void setModel(int i, Model m) { M[i]=m; } public void setInput(Vec in) { x=in; _x=x.array(); } public int getSize() { return n; } public void dispose() { s.dispose(); w.dispose(); Z.dispose(); for (int i=0; i<m; i++) M[i].dispose(); super.dispose(); } public void infer() { // get models to compute energies. // for (int i=0; i<m; i++) { M[i].infer(); M[i].compute(); } // compute relative posterior for (int i=0; i<m; i++) _s[i] = M[i].getEnergy(); // collect energies double Emin=Mathx.min(_s); for (int i=0; i<m; i++) _s[i] = _w[i]*Math.exp(Emin-_s[i]); // compute partition function, normalise posterior Z.set(Mathx.sum(_s)); // compute parition function k=Mathx.argmax(_s); // get MAP model Mathx.mul(_s,1/Z.value); // normalise posterior s.changed(); } public void compute() { /* compute gradients weighted by posterior */ Mathx.zero(_g); for (int i=0; i<m; i++) { double [] phi = M[i].getGradient(); for (int j=0; j<n; j++) _g[j] += _s[i]*phi[j]; } } public double getEnergy() { return -Math.log(Z.value); } public double [] getGradient() { return _g; } public Functionx functionx() { return null; } public void run() { infer(); } public Trainer getTrainer() { return new Trainer(); } public class Trainer implements Model.Trainer { Model.Trainer T[]; VDouble rate; VVector dw; double batch, _dw[]; public Trainer() { T=new Model.Trainer[m]; // should all be null rate=new VDouble("rate",0.001); dw=new VVector("dw",m); _dw=dw.array(); } public void setTrainer(int i,Model.Trainer t) { T[i]=t; } public void dispose() { rate.dispose(); dw.dispose(); } public void accumulate() { accumulate(1.0); } public void accumulate(double w) { batch+=w; for (int i=0;i<m; i++) { if (T[i]!=null) T[i].accumulate(w*_s[i]); // sweet } // now accumulate info about priors Mathx.add(_dw,_s); } public void oneshot() { accumulate(1.0); flush(); } public void flush() { for (int i=0; i<m; i++) if (T[i]!=null) T[i].flush(); double lambda=Mathx.dot(_w,_dw)/Mathx.dot(_w,_w); double nu=rate.value/batch; dw.changed(); for (int i=0; i<m; i++) { _w[i] *= Math.exp(nu*(_dw[i]-lambda*_w[i])); // update w } Mathx.zero(_dw); batch=0; // normalise Mathx.mul(_w,1/Mathx.sum(_w)); w.changed(); } public void reset() { for (int i=0; i<m; i++) if (T[i]!=null) T[i].reset(); Mathx.zero(_dw); batch=0; } } }