annotate src/samer/models/Mixture.java @ 3:15b93db27c04

Get StreamSource to compile, update args for demo
author samer
date Fri, 05 Apr 2019 17:00:18 +0100
parents bf79fb79ee13
children
rev   line source
samer@0 1 /*
samer@0 2 * Copyright (c) 2002, Samer Abdallah, King's College London.
samer@0 3 * All rights reserved.
samer@0 4 *
samer@0 5 * This software is provided AS IS and WITHOUT ANY WARRANTY;
samer@0 6 * without even the implied warranty of MERCHANTABILITY or
samer@0 7 * FITNESS FOR A PARTICULAR PURPOSE.
samer@0 8 */
samer@0 9
samer@0 10 package samer.models;
samer@0 11
samer@0 12 import samer.core.*;
samer@0 13 import samer.core.types.*;
samer@0 14 import samer.maths.*;
samer@0 15 import samer.maths.opt.*;
samer@0 16 import samer.tools.*;
samer@0 17
samer@0 18
samer@0 19 public class Mixture extends NamedTask implements Model
samer@0 20 {
samer@0 21 private Model M[]; // models
samer@0 22 private int n, m; // size of vector, num models
samer@0 23 private Vec x; // input
samer@0 24 private VVector w; // prior weights
samer@0 25 private VVector s; // posterior
samer@0 26 private int k; // MAP estimate
samer@0 27 private VDouble Z; // Parition function, ie p(x)
samer@0 28 private double[] _x,_s,_w,_g;
samer@0 29
samer@0 30 public Mixture( Vec input, int m) { this(input.size(), m); setInput(input); }
samer@0 31 public Mixture( int N, int L)
samer@0 32 {
samer@0 33 super("mixture");
samer@0 34 Shell.push(node);
samer@0 35
samer@0 36 n = N;
samer@0 37 m = L;
samer@0 38
samer@0 39 x = null;
samer@0 40 w = new VVector("prior",m);
samer@0 41 s = new VVector("posterior",m);
samer@0 42 Z = new VDouble("Z");
samer@0 43 M = new Model[m];
samer@0 44 Shell.pop();
samer@0 45
samer@0 46 _s=s.array();
samer@0 47 _w=w.array();
samer@0 48 _g=new double[n];
samer@0 49 Mathx.set(_w,new Constant(1.0/L));
samer@0 50 }
samer@0 51
samer@0 52 public VVector prior() { return w; }
samer@0 53 public VVector posterior() { return s; }
samer@0 54 public void setModel(int i, Model m) { M[i]=m; }
samer@0 55 public void setInput(Vec in) { x=in; _x=x.array(); }
samer@0 56 public int getSize() { return n; }
samer@0 57
samer@0 58 public void dispose()
samer@0 59 {
samer@0 60 s.dispose();
samer@0 61 w.dispose();
samer@0 62 Z.dispose();
samer@0 63 for (int i=0; i<m; i++) M[i].dispose();
samer@0 64 super.dispose();
samer@0 65 }
samer@0 66
samer@0 67 public void infer() {
samer@0 68 // get models to compute energies.
samer@0 69 // for (int i=0; i<m; i++) { M[i].infer(); M[i].compute(); }
samer@0 70
samer@0 71 // compute relative posterior
samer@0 72 for (int i=0; i<m; i++) _s[i] = M[i].getEnergy(); // collect energies
samer@0 73 double Emin=Mathx.min(_s);
samer@0 74 for (int i=0; i<m; i++) _s[i] = _w[i]*Math.exp(Emin-_s[i]);
samer@0 75
samer@0 76 // compute partition function, normalise posterior
samer@0 77 Z.set(Mathx.sum(_s)); // compute parition function
samer@0 78 k=Mathx.argmax(_s); // get MAP model
samer@0 79 Mathx.mul(_s,1/Z.value); // normalise posterior
samer@0 80 s.changed();
samer@0 81 }
samer@0 82
samer@0 83 public void compute()
samer@0 84 {
samer@0 85 /* compute gradients weighted by posterior */
samer@0 86 Mathx.zero(_g);
samer@0 87 for (int i=0; i<m; i++) {
samer@0 88 double [] phi = M[i].getGradient();
samer@0 89 for (int j=0; j<n; j++) _g[j] += _s[i]*phi[j];
samer@0 90 }
samer@0 91 }
samer@0 92
samer@0 93 public double getEnergy() { return -Math.log(Z.value); }
samer@0 94 public double [] getGradient() { return _g; }
samer@0 95
samer@0 96 public Functionx functionx() { return null; }
samer@0 97
samer@0 98 public void run() { infer(); }
samer@0 99
samer@0 100 public Trainer getTrainer() { return new Trainer(); }
samer@0 101
samer@0 102 public class Trainer implements Model.Trainer
samer@0 103 {
samer@0 104 Model.Trainer T[];
samer@0 105 VDouble rate;
samer@0 106 VVector dw;
samer@0 107 double batch, _dw[];
samer@0 108
samer@0 109 public Trainer() {
samer@0 110 T=new Model.Trainer[m]; // should all be null
samer@0 111 rate=new VDouble("rate",0.001);
samer@0 112 dw=new VVector("dw",m);
samer@0 113 _dw=dw.array();
samer@0 114 }
samer@0 115
samer@0 116 public void setTrainer(int i,Model.Trainer t) { T[i]=t; }
samer@0 117 public void dispose() { rate.dispose(); dw.dispose(); }
samer@0 118
samer@0 119 public void accumulate() { accumulate(1.0); }
samer@0 120 public void accumulate(double w) {
samer@0 121 batch+=w;
samer@0 122 for (int i=0;i<m; i++) {
samer@0 123 if (T[i]!=null) T[i].accumulate(w*_s[i]); // sweet
samer@0 124 }
samer@0 125
samer@0 126 // now accumulate info about priors
samer@0 127 Mathx.add(_dw,_s);
samer@0 128 }
samer@0 129
samer@0 130 public void oneshot() { accumulate(1.0); flush(); }
samer@0 131 public void flush() {
samer@0 132 for (int i=0; i<m; i++) if (T[i]!=null) T[i].flush();
samer@0 133 double lambda=Mathx.dot(_w,_dw)/Mathx.dot(_w,_w);
samer@0 134 double nu=rate.value/batch;
samer@0 135
samer@0 136 dw.changed();
samer@0 137 for (int i=0; i<m; i++) {
samer@0 138 _w[i] *= Math.exp(nu*(_dw[i]-lambda*_w[i])); // update w
samer@0 139 }
samer@0 140 Mathx.zero(_dw); batch=0;
samer@0 141
samer@0 142 // normalise
samer@0 143 Mathx.mul(_w,1/Mathx.sum(_w));
samer@0 144 w.changed();
samer@0 145 }
samer@0 146 public void reset() {
samer@0 147 for (int i=0; i<m; i++) if (T[i]!=null) T[i].reset();
samer@0 148 Mathx.zero(_dw);
samer@0 149 batch=0;
samer@0 150 }
samer@0 151 }
samer@0 152 }
samer@0 153