comparison src/samer/models/MOGVector.java @ 0:bf79fb79ee13

Initial Mercurial check in.
author samer
date Tue, 17 Jan 2012 17:50:20 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:bf79fb79ee13
1 package samer.models;
2
3 import samer.core.*;
4 import samer.core.types.*;
5 import samer.maths.*;
6 import samer.tools.*;
7 import java.util.*;
8
9 public class MOGVector extends NamedTask
10 {
11 Vec input;
12
13 int N, M; // num states, number of inputs
14 VVector s; // states
15 VVector l; // likelihoods
16 VDouble L; // total likelihood
17
18 Matrix t; // deviations: t=(x-mu)/sig
19 Matrix post; // posterior over states
20 Matrix w, mu, sig; // weights, means, stddevs
21 Matrix dw, dmu, dsig; // accumulated stats for learning
22
23 VDouble nuw, numu, nusig; // learning rates
24
25
26 // ------------ private bits ------------------------
27
28 private final static double Q=0.5*Math.log(2*Math.PI);
29 private double[] tmp;
30 private double [] _s;
31 private double [] _l;
32 private double [][] _t;
33 private double [][] _w;
34 private double [][] _mu;
35 private double [][] _sig;
36 private double [][] _post;
37 // ---------------------------------------------------
38
39 public MOGVector(Vec in, int n)
40 {
41 super("mog");
42 input=in;
43
44 Shell.push(node);
45 N=n; M=in.size();
46
47 s = new VVector("state",M);
48 l = new VVector("-log p(s)",M);
49 L = new VDouble("likelihood");
50
51 t = new Matrix("t",N,M);
52 post = new Matrix("p(s|x)",N,M);
53 mu = new Matrix("means",N,M);
54 sig = new Matrix("sigmas",N,M);
55 w = new Matrix("weight",N,M);
56 dmu = new Matrix("dmu",N,M);
57 dsig = new Matrix("dsig",N,M);
58 dw = new Matrix("dw",N,M);
59
60 tmp=new double[M];
61
62 Shell.pop();
63
64 // initialise parameters
65
66 // weights...
67 w.set(new Constant(1.0/N));
68 w.changed();
69
70 // means...
71 for (int i=0; i<N; i++) {
72 mu.setMatrix(i,i,0,M-1,new Jama.Matrix(1,M,(double)i));
73 }
74 mu.changed();
75
76 // sigmas
77 sig.set(new Constant(1));
78 sig.changed();
79
80
81 // ----- initialise private bits ----------
82 _s=s.array();
83 _l=l.array();
84 _t=t.getArray();
85 _w=w.getArray();
86 _mu=mu.getArray();
87 _sig=sig.getArray();
88 _post=post.getArray();
89 }
90
91 // normalise column sum of w, return array of sums
92 private void normalise(double [][] w, double[] sum) {
93 // make sure weights are normalised properly
94 Mathx.zero(sum);
95 for (int i=0; i<N; i++) Mathx.add(sum,w[i]);
96 for (int i=0; i<N; i++) Mathx.div(w[i],sum);
97 }
98
99 public VFunction getPDF() {
100 PDF fn=new PDF();
101 VFunction vfn=new VFunction("pdf",fn);
102 fn.setupObservers(vfn);
103 return vfn;
104 }
105
106 public void run()
107 {
108 // matlab equivalent code:
109 // t = (x - mu)./sig
110 // post = (w./sig).*exp(-0.5*t.^2)
111 // [dummy, s] = max(post,2)
112 // Z = sum(post,2)
113 // post = post./(Z*ones(1,N))
114 // Z=Z*Q;
115
116 {
117 // Vec.Iterator it=input.iterator();
118
119 double [] x=input.array();
120
121 for (int i=0; i<N; i++) {
122 for (int j=0; j<M; j++) {
123 _t[i][j] = (x[j]-_mu[i][j])/_sig[i][j];
124 _post[i][j] = (_w[i][j]/_sig[i][j])*Math.exp(-0.5*_t[i][j]*_t[i][j]);
125 }
126 }
127
128 // this computes partition function
129 // and normalises posterior in one go
130 normalise(_post,_l);
131
132 // get MAP state
133 for (int j=0; j<M; j++) {
134 int state=0;
135 double pmax=_post[0][j];
136
137 for (int i=1; i<N; i++) {
138 if (_post[i][j]>pmax) { state=i; pmax=_post[i][j]; }
139 }
140 _s[j]=state;
141 }
142
143 for (int j=0; j<M; j++) _l[j]=-Math.log(_l[j]);
144
145 L.value = Mathx.sum(_l);
146 }
147
148 L.changed();
149 t.changed();
150 s.changed();
151 l.changed();
152 post.changed();
153 }
154
155 public Task learnTask() {
156 return new AnonymousTask() {
157 // buffer changes to parameters
158 double [][] _dw=dw.getArray();
159 double [][] _dmu=dmu.getArray();
160 double [][] _dsig=dsig.getArray();
161
162 public void starting() {
163 dw.zero();
164 dmu.zero();
165 dsig.zero();
166 normalise(w.getArray(),tmp);
167 }
168
169 public void run() {
170 for (int i=0; i<N; i++) {
171 for (int j=0; j<M; j++) {
172 double pp=_post[i][j], tt=_t[i][j];
173 _dw[i][j] += pp;
174 _dmu[i][j] += pp*tt;
175 _dsig[i][j]+= pp*(tt*tt-1);
176 }
177 }
178 }
179 };
180 }
181
182 public Task flushTask() {
183 Shell.push(node);
184
185 try {
186 return new AnonymousTask() {
187
188 VDouble nuw=new VDouble("weights.learn.rate",.001);
189 VDouble numu=new VDouble("means.learn.rate",.001);
190 VDouble nusig=new VDouble("sigmas.learn.rate",.001);
191
192 double [][] _w=w.getArray();
193 double [][] _dw=dw.getArray();
194
195 public void run()
196 {
197 // mu += numu*dmu.*sig
198 // sig += nusig*dsig.*sig;
199 // lambda = sum(w.*dw,2)./sum(w.*w,2)
200 // dw -= (lambda*ones(1,N)) .* w
201 // w *= exp(nu*dw)
202
203 dmu.arrayTimesEquals(sig);
204 dmu.timesEquals(nusig.value);
205 mu.plusEquals(dmu);
206 mu.changed();
207 dmu.zero();
208
209 dsig.arrayTimesEquals(sig);
210 dsig.timesEquals(nusig.value);
211 sig.plusEquals(dsig);
212 sig.changed();
213 dsig.zero();
214
215 {
216 // the effect of this is to project
217 // dw away from w. the resulting vector
218 // is then added to the log of w
219
220 double nu=nuw.value;
221
222 for (int j=0; j<M; j++) {
223 double w2=0, lambda=0;
224 for (int i=0; i<N; i++) {
225 lambda += _w[i][j]*_dw[i][j];
226 w2 += _w[i][j]*_w[i][j];
227 }
228 tmp[j]=lambda/w2;
229 }
230
231 for (int i=0; i<N; i++) {
232 for (int j=0; j<M; j++) {
233 _w[i][j] *= Math.exp(
234 nu*(_dw[i][j] - tmp[j]*_w[i][j])
235 ); // update w
236 }
237 }
238 }
239
240 normalise(w.getArray(),tmp);
241 w.changed();
242 dw.zero();
243 }
244 };
245 } finally { Shell.pop(); }
246 }
247
248 class PDF extends Function implements Observer {
249 VInteger index;
250 Viewable vbl=null;
251
252 public PDF() {
253 index=new VInteger("index",M/2);
254 index.setRange(0,M-1);
255 index.addObserver(this);
256 }
257
258 public void dispose() { index.dispose(); }
259
260 public double apply(double x) {
261 double Z=0;
262 for (int i=0; i<N; i++) {
263 int j=index.value;
264 double t=(x-_mu[i][j])/_sig[i][j];
265 Z+=(_w[i][j]/_sig[i][j])*Math.exp(-0.5*t*t);
266 }
267 return Z;
268 }
269 public String format(String x) {
270 return "mogpdf("+x+")";
271 }
272
273 public void setupObservers(Viewable v) {
274 w.addObserver(this);
275 mu.addObserver(this);
276 sig.addObserver(this);
277 vbl=v;
278 }
279
280 public void update(Observable o, Object arg) {
281 if (arg!=Viewable.DISPOSING) vbl.changed();
282 }
283 }
284 }