annotate src/samer/models/MOGVector.java @ 8:5e3cbbf173aa tip

Reorganise some more
author samer
date Fri, 05 Apr 2019 22:41:58 +0100
parents bf79fb79ee13
children
rev   line source
samer@0 1 package samer.models;
samer@0 2
samer@0 3 import samer.core.*;
samer@0 4 import samer.core.types.*;
samer@0 5 import samer.maths.*;
samer@0 6 import samer.tools.*;
samer@0 7 import java.util.*;
samer@0 8
samer@0 9 public class MOGVector extends NamedTask
samer@0 10 {
samer@0 11 Vec input;
samer@0 12
samer@0 13 int N, M; // num states, number of inputs
samer@0 14 VVector s; // states
samer@0 15 VVector l; // likelihoods
samer@0 16 VDouble L; // total likelihood
samer@0 17
samer@0 18 Matrix t; // deviations: t=(x-mu)/sig
samer@0 19 Matrix post; // posterior over states
samer@0 20 Matrix w, mu, sig; // weights, means, stddevs
samer@0 21 Matrix dw, dmu, dsig; // accumulated stats for learning
samer@0 22
samer@0 23 VDouble nuw, numu, nusig; // learning rates
samer@0 24
samer@0 25
samer@0 26 // ------------ private bits ------------------------
samer@0 27
samer@0 28 private final static double Q=0.5*Math.log(2*Math.PI);
samer@0 29 private double[] tmp;
samer@0 30 private double [] _s;
samer@0 31 private double [] _l;
samer@0 32 private double [][] _t;
samer@0 33 private double [][] _w;
samer@0 34 private double [][] _mu;
samer@0 35 private double [][] _sig;
samer@0 36 private double [][] _post;
samer@0 37 // ---------------------------------------------------
samer@0 38
samer@0 39 public MOGVector(Vec in, int n)
samer@0 40 {
samer@0 41 super("mog");
samer@0 42 input=in;
samer@0 43
samer@0 44 Shell.push(node);
samer@0 45 N=n; M=in.size();
samer@0 46
samer@0 47 s = new VVector("state",M);
samer@0 48 l = new VVector("-log p(s)",M);
samer@0 49 L = new VDouble("likelihood");
samer@0 50
samer@0 51 t = new Matrix("t",N,M);
samer@0 52 post = new Matrix("p(s|x)",N,M);
samer@0 53 mu = new Matrix("means",N,M);
samer@0 54 sig = new Matrix("sigmas",N,M);
samer@0 55 w = new Matrix("weight",N,M);
samer@0 56 dmu = new Matrix("dmu",N,M);
samer@0 57 dsig = new Matrix("dsig",N,M);
samer@0 58 dw = new Matrix("dw",N,M);
samer@0 59
samer@0 60 tmp=new double[M];
samer@0 61
samer@0 62 Shell.pop();
samer@0 63
samer@0 64 // initialise parameters
samer@0 65
samer@0 66 // weights...
samer@0 67 w.set(new Constant(1.0/N));
samer@0 68 w.changed();
samer@0 69
samer@0 70 // means...
samer@0 71 for (int i=0; i<N; i++) {
samer@0 72 mu.setMatrix(i,i,0,M-1,new Jama.Matrix(1,M,(double)i));
samer@0 73 }
samer@0 74 mu.changed();
samer@0 75
samer@0 76 // sigmas
samer@0 77 sig.set(new Constant(1));
samer@0 78 sig.changed();
samer@0 79
samer@0 80
samer@0 81 // ----- initialise private bits ----------
samer@0 82 _s=s.array();
samer@0 83 _l=l.array();
samer@0 84 _t=t.getArray();
samer@0 85 _w=w.getArray();
samer@0 86 _mu=mu.getArray();
samer@0 87 _sig=sig.getArray();
samer@0 88 _post=post.getArray();
samer@0 89 }
samer@0 90
samer@0 91 // normalise column sum of w, return array of sums
samer@0 92 private void normalise(double [][] w, double[] sum) {
samer@0 93 // make sure weights are normalised properly
samer@0 94 Mathx.zero(sum);
samer@0 95 for (int i=0; i<N; i++) Mathx.add(sum,w[i]);
samer@0 96 for (int i=0; i<N; i++) Mathx.div(w[i],sum);
samer@0 97 }
samer@0 98
samer@0 99 public VFunction getPDF() {
samer@0 100 PDF fn=new PDF();
samer@0 101 VFunction vfn=new VFunction("pdf",fn);
samer@0 102 fn.setupObservers(vfn);
samer@0 103 return vfn;
samer@0 104 }
samer@0 105
samer@0 106 public void run()
samer@0 107 {
samer@0 108 // matlab equivalent code:
samer@0 109 // t = (x - mu)./sig
samer@0 110 // post = (w./sig).*exp(-0.5*t.^2)
samer@0 111 // [dummy, s] = max(post,2)
samer@0 112 // Z = sum(post,2)
samer@0 113 // post = post./(Z*ones(1,N))
samer@0 114 // Z=Z*Q;
samer@0 115
samer@0 116 {
samer@0 117 // Vec.Iterator it=input.iterator();
samer@0 118
samer@0 119 double [] x=input.array();
samer@0 120
samer@0 121 for (int i=0; i<N; i++) {
samer@0 122 for (int j=0; j<M; j++) {
samer@0 123 _t[i][j] = (x[j]-_mu[i][j])/_sig[i][j];
samer@0 124 _post[i][j] = (_w[i][j]/_sig[i][j])*Math.exp(-0.5*_t[i][j]*_t[i][j]);
samer@0 125 }
samer@0 126 }
samer@0 127
samer@0 128 // this computes partition function
samer@0 129 // and normalises posterior in one go
samer@0 130 normalise(_post,_l);
samer@0 131
samer@0 132 // get MAP state
samer@0 133 for (int j=0; j<M; j++) {
samer@0 134 int state=0;
samer@0 135 double pmax=_post[0][j];
samer@0 136
samer@0 137 for (int i=1; i<N; i++) {
samer@0 138 if (_post[i][j]>pmax) { state=i; pmax=_post[i][j]; }
samer@0 139 }
samer@0 140 _s[j]=state;
samer@0 141 }
samer@0 142
samer@0 143 for (int j=0; j<M; j++) _l[j]=-Math.log(_l[j]);
samer@0 144
samer@0 145 L.value = Mathx.sum(_l);
samer@0 146 }
samer@0 147
samer@0 148 L.changed();
samer@0 149 t.changed();
samer@0 150 s.changed();
samer@0 151 l.changed();
samer@0 152 post.changed();
samer@0 153 }
samer@0 154
samer@0 155 public Task learnTask() {
samer@0 156 return new AnonymousTask() {
samer@0 157 // buffer changes to parameters
samer@0 158 double [][] _dw=dw.getArray();
samer@0 159 double [][] _dmu=dmu.getArray();
samer@0 160 double [][] _dsig=dsig.getArray();
samer@0 161
samer@0 162 public void starting() {
samer@0 163 dw.zero();
samer@0 164 dmu.zero();
samer@0 165 dsig.zero();
samer@0 166 normalise(w.getArray(),tmp);
samer@0 167 }
samer@0 168
samer@0 169 public void run() {
samer@0 170 for (int i=0; i<N; i++) {
samer@0 171 for (int j=0; j<M; j++) {
samer@0 172 double pp=_post[i][j], tt=_t[i][j];
samer@0 173 _dw[i][j] += pp;
samer@0 174 _dmu[i][j] += pp*tt;
samer@0 175 _dsig[i][j]+= pp*(tt*tt-1);
samer@0 176 }
samer@0 177 }
samer@0 178 }
samer@0 179 };
samer@0 180 }
samer@0 181
samer@0 182 public Task flushTask() {
samer@0 183 Shell.push(node);
samer@0 184
samer@0 185 try {
samer@0 186 return new AnonymousTask() {
samer@0 187
samer@0 188 VDouble nuw=new VDouble("weights.learn.rate",.001);
samer@0 189 VDouble numu=new VDouble("means.learn.rate",.001);
samer@0 190 VDouble nusig=new VDouble("sigmas.learn.rate",.001);
samer@0 191
samer@0 192 double [][] _w=w.getArray();
samer@0 193 double [][] _dw=dw.getArray();
samer@0 194
samer@0 195 public void run()
samer@0 196 {
samer@0 197 // mu += numu*dmu.*sig
samer@0 198 // sig += nusig*dsig.*sig;
samer@0 199 // lambda = sum(w.*dw,2)./sum(w.*w,2)
samer@0 200 // dw -= (lambda*ones(1,N)) .* w
samer@0 201 // w *= exp(nu*dw)
samer@0 202
samer@0 203 dmu.arrayTimesEquals(sig);
samer@0 204 dmu.timesEquals(nusig.value);
samer@0 205 mu.plusEquals(dmu);
samer@0 206 mu.changed();
samer@0 207 dmu.zero();
samer@0 208
samer@0 209 dsig.arrayTimesEquals(sig);
samer@0 210 dsig.timesEquals(nusig.value);
samer@0 211 sig.plusEquals(dsig);
samer@0 212 sig.changed();
samer@0 213 dsig.zero();
samer@0 214
samer@0 215 {
samer@0 216 // the effect of this is to project
samer@0 217 // dw away from w. the resulting vector
samer@0 218 // is then added to the log of w
samer@0 219
samer@0 220 double nu=nuw.value;
samer@0 221
samer@0 222 for (int j=0; j<M; j++) {
samer@0 223 double w2=0, lambda=0;
samer@0 224 for (int i=0; i<N; i++) {
samer@0 225 lambda += _w[i][j]*_dw[i][j];
samer@0 226 w2 += _w[i][j]*_w[i][j];
samer@0 227 }
samer@0 228 tmp[j]=lambda/w2;
samer@0 229 }
samer@0 230
samer@0 231 for (int i=0; i<N; i++) {
samer@0 232 for (int j=0; j<M; j++) {
samer@0 233 _w[i][j] *= Math.exp(
samer@0 234 nu*(_dw[i][j] - tmp[j]*_w[i][j])
samer@0 235 ); // update w
samer@0 236 }
samer@0 237 }
samer@0 238 }
samer@0 239
samer@0 240 normalise(w.getArray(),tmp);
samer@0 241 w.changed();
samer@0 242 dw.zero();
samer@0 243 }
samer@0 244 };
samer@0 245 } finally { Shell.pop(); }
samer@0 246 }
samer@0 247
samer@0 248 class PDF extends Function implements Observer {
samer@0 249 VInteger index;
samer@0 250 Viewable vbl=null;
samer@0 251
samer@0 252 public PDF() {
samer@0 253 index=new VInteger("index",M/2);
samer@0 254 index.setRange(0,M-1);
samer@0 255 index.addObserver(this);
samer@0 256 }
samer@0 257
samer@0 258 public void dispose() { index.dispose(); }
samer@0 259
samer@0 260 public double apply(double x) {
samer@0 261 double Z=0;
samer@0 262 for (int i=0; i<N; i++) {
samer@0 263 int j=index.value;
samer@0 264 double t=(x-_mu[i][j])/_sig[i][j];
samer@0 265 Z+=(_w[i][j]/_sig[i][j])*Math.exp(-0.5*t*t);
samer@0 266 }
samer@0 267 return Z;
samer@0 268 }
samer@0 269 public String format(String x) {
samer@0 270 return "mogpdf("+x+")";
samer@0 271 }
samer@0 272
samer@0 273 public void setupObservers(Viewable v) {
samer@0 274 w.addObserver(this);
samer@0 275 mu.addObserver(this);
samer@0 276 sig.addObserver(this);
samer@0 277 vbl=v;
samer@0 278 }
samer@0 279
samer@0 280 public void update(Observable o, Object arg) {
samer@0 281 if (arg!=Viewable.DISPOSING) vbl.changed();
samer@0 282 }
samer@0 283 }
samer@0 284 }