samer@0
|
1 package samer.models;
|
samer@0
|
2
|
samer@0
|
3 import samer.core.*;
|
samer@0
|
4 import samer.core.types.*;
|
samer@0
|
5 import samer.maths.*;
|
samer@0
|
6 import samer.tools.*;
|
samer@0
|
7 import java.util.*;
|
samer@0
|
8
|
samer@0
|
9 public class MOGVector extends NamedTask
|
samer@0
|
10 {
|
samer@0
|
11 Vec input;
|
samer@0
|
12
|
samer@0
|
13 int N, M; // num states, number of inputs
|
samer@0
|
14 VVector s; // states
|
samer@0
|
15 VVector l; // likelihoods
|
samer@0
|
16 VDouble L; // total likelihood
|
samer@0
|
17
|
samer@0
|
18 Matrix t; // deviations: t=(x-mu)/sig
|
samer@0
|
19 Matrix post; // posterior over states
|
samer@0
|
20 Matrix w, mu, sig; // weights, means, stddevs
|
samer@0
|
21 Matrix dw, dmu, dsig; // accumulated stats for learning
|
samer@0
|
22
|
samer@0
|
23 VDouble nuw, numu, nusig; // learning rates
|
samer@0
|
24
|
samer@0
|
25
|
samer@0
|
26 // ------------ private bits ------------------------
|
samer@0
|
27
|
samer@0
|
28 private final static double Q=0.5*Math.log(2*Math.PI);
|
samer@0
|
29 private double[] tmp;
|
samer@0
|
30 private double [] _s;
|
samer@0
|
31 private double [] _l;
|
samer@0
|
32 private double [][] _t;
|
samer@0
|
33 private double [][] _w;
|
samer@0
|
34 private double [][] _mu;
|
samer@0
|
35 private double [][] _sig;
|
samer@0
|
36 private double [][] _post;
|
samer@0
|
37 // ---------------------------------------------------
|
samer@0
|
38
|
samer@0
|
39 public MOGVector(Vec in, int n)
|
samer@0
|
40 {
|
samer@0
|
41 super("mog");
|
samer@0
|
42 input=in;
|
samer@0
|
43
|
samer@0
|
44 Shell.push(node);
|
samer@0
|
45 N=n; M=in.size();
|
samer@0
|
46
|
samer@0
|
47 s = new VVector("state",M);
|
samer@0
|
48 l = new VVector("-log p(s)",M);
|
samer@0
|
49 L = new VDouble("likelihood");
|
samer@0
|
50
|
samer@0
|
51 t = new Matrix("t",N,M);
|
samer@0
|
52 post = new Matrix("p(s|x)",N,M);
|
samer@0
|
53 mu = new Matrix("means",N,M);
|
samer@0
|
54 sig = new Matrix("sigmas",N,M);
|
samer@0
|
55 w = new Matrix("weight",N,M);
|
samer@0
|
56 dmu = new Matrix("dmu",N,M);
|
samer@0
|
57 dsig = new Matrix("dsig",N,M);
|
samer@0
|
58 dw = new Matrix("dw",N,M);
|
samer@0
|
59
|
samer@0
|
60 tmp=new double[M];
|
samer@0
|
61
|
samer@0
|
62 Shell.pop();
|
samer@0
|
63
|
samer@0
|
64 // initialise parameters
|
samer@0
|
65
|
samer@0
|
66 // weights...
|
samer@0
|
67 w.set(new Constant(1.0/N));
|
samer@0
|
68 w.changed();
|
samer@0
|
69
|
samer@0
|
70 // means...
|
samer@0
|
71 for (int i=0; i<N; i++) {
|
samer@0
|
72 mu.setMatrix(i,i,0,M-1,new Jama.Matrix(1,M,(double)i));
|
samer@0
|
73 }
|
samer@0
|
74 mu.changed();
|
samer@0
|
75
|
samer@0
|
76 // sigmas
|
samer@0
|
77 sig.set(new Constant(1));
|
samer@0
|
78 sig.changed();
|
samer@0
|
79
|
samer@0
|
80
|
samer@0
|
81 // ----- initialise private bits ----------
|
samer@0
|
82 _s=s.array();
|
samer@0
|
83 _l=l.array();
|
samer@0
|
84 _t=t.getArray();
|
samer@0
|
85 _w=w.getArray();
|
samer@0
|
86 _mu=mu.getArray();
|
samer@0
|
87 _sig=sig.getArray();
|
samer@0
|
88 _post=post.getArray();
|
samer@0
|
89 }
|
samer@0
|
90
|
samer@0
|
91 // normalise column sum of w, return array of sums
|
samer@0
|
92 private void normalise(double [][] w, double[] sum) {
|
samer@0
|
93 // make sure weights are normalised properly
|
samer@0
|
94 Mathx.zero(sum);
|
samer@0
|
95 for (int i=0; i<N; i++) Mathx.add(sum,w[i]);
|
samer@0
|
96 for (int i=0; i<N; i++) Mathx.div(w[i],sum);
|
samer@0
|
97 }
|
samer@0
|
98
|
samer@0
|
99 public VFunction getPDF() {
|
samer@0
|
100 PDF fn=new PDF();
|
samer@0
|
101 VFunction vfn=new VFunction("pdf",fn);
|
samer@0
|
102 fn.setupObservers(vfn);
|
samer@0
|
103 return vfn;
|
samer@0
|
104 }
|
samer@0
|
105
|
samer@0
|
106 public void run()
|
samer@0
|
107 {
|
samer@0
|
108 // matlab equivalent code:
|
samer@0
|
109 // t = (x - mu)./sig
|
samer@0
|
110 // post = (w./sig).*exp(-0.5*t.^2)
|
samer@0
|
111 // [dummy, s] = max(post,2)
|
samer@0
|
112 // Z = sum(post,2)
|
samer@0
|
113 // post = post./(Z*ones(1,N))
|
samer@0
|
114 // Z=Z*Q;
|
samer@0
|
115
|
samer@0
|
116 {
|
samer@0
|
117 // Vec.Iterator it=input.iterator();
|
samer@0
|
118
|
samer@0
|
119 double [] x=input.array();
|
samer@0
|
120
|
samer@0
|
121 for (int i=0; i<N; i++) {
|
samer@0
|
122 for (int j=0; j<M; j++) {
|
samer@0
|
123 _t[i][j] = (x[j]-_mu[i][j])/_sig[i][j];
|
samer@0
|
124 _post[i][j] = (_w[i][j]/_sig[i][j])*Math.exp(-0.5*_t[i][j]*_t[i][j]);
|
samer@0
|
125 }
|
samer@0
|
126 }
|
samer@0
|
127
|
samer@0
|
128 // this computes partition function
|
samer@0
|
129 // and normalises posterior in one go
|
samer@0
|
130 normalise(_post,_l);
|
samer@0
|
131
|
samer@0
|
132 // get MAP state
|
samer@0
|
133 for (int j=0; j<M; j++) {
|
samer@0
|
134 int state=0;
|
samer@0
|
135 double pmax=_post[0][j];
|
samer@0
|
136
|
samer@0
|
137 for (int i=1; i<N; i++) {
|
samer@0
|
138 if (_post[i][j]>pmax) { state=i; pmax=_post[i][j]; }
|
samer@0
|
139 }
|
samer@0
|
140 _s[j]=state;
|
samer@0
|
141 }
|
samer@0
|
142
|
samer@0
|
143 for (int j=0; j<M; j++) _l[j]=-Math.log(_l[j]);
|
samer@0
|
144
|
samer@0
|
145 L.value = Mathx.sum(_l);
|
samer@0
|
146 }
|
samer@0
|
147
|
samer@0
|
148 L.changed();
|
samer@0
|
149 t.changed();
|
samer@0
|
150 s.changed();
|
samer@0
|
151 l.changed();
|
samer@0
|
152 post.changed();
|
samer@0
|
153 }
|
samer@0
|
154
|
samer@0
|
155 public Task learnTask() {
|
samer@0
|
156 return new AnonymousTask() {
|
samer@0
|
157 // buffer changes to parameters
|
samer@0
|
158 double [][] _dw=dw.getArray();
|
samer@0
|
159 double [][] _dmu=dmu.getArray();
|
samer@0
|
160 double [][] _dsig=dsig.getArray();
|
samer@0
|
161
|
samer@0
|
162 public void starting() {
|
samer@0
|
163 dw.zero();
|
samer@0
|
164 dmu.zero();
|
samer@0
|
165 dsig.zero();
|
samer@0
|
166 normalise(w.getArray(),tmp);
|
samer@0
|
167 }
|
samer@0
|
168
|
samer@0
|
169 public void run() {
|
samer@0
|
170 for (int i=0; i<N; i++) {
|
samer@0
|
171 for (int j=0; j<M; j++) {
|
samer@0
|
172 double pp=_post[i][j], tt=_t[i][j];
|
samer@0
|
173 _dw[i][j] += pp;
|
samer@0
|
174 _dmu[i][j] += pp*tt;
|
samer@0
|
175 _dsig[i][j]+= pp*(tt*tt-1);
|
samer@0
|
176 }
|
samer@0
|
177 }
|
samer@0
|
178 }
|
samer@0
|
179 };
|
samer@0
|
180 }
|
samer@0
|
181
|
samer@0
|
182 public Task flushTask() {
|
samer@0
|
183 Shell.push(node);
|
samer@0
|
184
|
samer@0
|
185 try {
|
samer@0
|
186 return new AnonymousTask() {
|
samer@0
|
187
|
samer@0
|
188 VDouble nuw=new VDouble("weights.learn.rate",.001);
|
samer@0
|
189 VDouble numu=new VDouble("means.learn.rate",.001);
|
samer@0
|
190 VDouble nusig=new VDouble("sigmas.learn.rate",.001);
|
samer@0
|
191
|
samer@0
|
192 double [][] _w=w.getArray();
|
samer@0
|
193 double [][] _dw=dw.getArray();
|
samer@0
|
194
|
samer@0
|
195 public void run()
|
samer@0
|
196 {
|
samer@0
|
197 // mu += numu*dmu.*sig
|
samer@0
|
198 // sig += nusig*dsig.*sig;
|
samer@0
|
199 // lambda = sum(w.*dw,2)./sum(w.*w,2)
|
samer@0
|
200 // dw -= (lambda*ones(1,N)) .* w
|
samer@0
|
201 // w *= exp(nu*dw)
|
samer@0
|
202
|
samer@0
|
203 dmu.arrayTimesEquals(sig);
|
samer@0
|
204 dmu.timesEquals(nusig.value);
|
samer@0
|
205 mu.plusEquals(dmu);
|
samer@0
|
206 mu.changed();
|
samer@0
|
207 dmu.zero();
|
samer@0
|
208
|
samer@0
|
209 dsig.arrayTimesEquals(sig);
|
samer@0
|
210 dsig.timesEquals(nusig.value);
|
samer@0
|
211 sig.plusEquals(dsig);
|
samer@0
|
212 sig.changed();
|
samer@0
|
213 dsig.zero();
|
samer@0
|
214
|
samer@0
|
215 {
|
samer@0
|
216 // the effect of this is to project
|
samer@0
|
217 // dw away from w. the resulting vector
|
samer@0
|
218 // is then added to the log of w
|
samer@0
|
219
|
samer@0
|
220 double nu=nuw.value;
|
samer@0
|
221
|
samer@0
|
222 for (int j=0; j<M; j++) {
|
samer@0
|
223 double w2=0, lambda=0;
|
samer@0
|
224 for (int i=0; i<N; i++) {
|
samer@0
|
225 lambda += _w[i][j]*_dw[i][j];
|
samer@0
|
226 w2 += _w[i][j]*_w[i][j];
|
samer@0
|
227 }
|
samer@0
|
228 tmp[j]=lambda/w2;
|
samer@0
|
229 }
|
samer@0
|
230
|
samer@0
|
231 for (int i=0; i<N; i++) {
|
samer@0
|
232 for (int j=0; j<M; j++) {
|
samer@0
|
233 _w[i][j] *= Math.exp(
|
samer@0
|
234 nu*(_dw[i][j] - tmp[j]*_w[i][j])
|
samer@0
|
235 ); // update w
|
samer@0
|
236 }
|
samer@0
|
237 }
|
samer@0
|
238 }
|
samer@0
|
239
|
samer@0
|
240 normalise(w.getArray(),tmp);
|
samer@0
|
241 w.changed();
|
samer@0
|
242 dw.zero();
|
samer@0
|
243 }
|
samer@0
|
244 };
|
samer@0
|
245 } finally { Shell.pop(); }
|
samer@0
|
246 }
|
samer@0
|
247
|
samer@0
|
248 class PDF extends Function implements Observer {
|
samer@0
|
249 VInteger index;
|
samer@0
|
250 Viewable vbl=null;
|
samer@0
|
251
|
samer@0
|
252 public PDF() {
|
samer@0
|
253 index=new VInteger("index",M/2);
|
samer@0
|
254 index.setRange(0,M-1);
|
samer@0
|
255 index.addObserver(this);
|
samer@0
|
256 }
|
samer@0
|
257
|
samer@0
|
258 public void dispose() { index.dispose(); }
|
samer@0
|
259
|
samer@0
|
260 public double apply(double x) {
|
samer@0
|
261 double Z=0;
|
samer@0
|
262 for (int i=0; i<N; i++) {
|
samer@0
|
263 int j=index.value;
|
samer@0
|
264 double t=(x-_mu[i][j])/_sig[i][j];
|
samer@0
|
265 Z+=(_w[i][j]/_sig[i][j])*Math.exp(-0.5*t*t);
|
samer@0
|
266 }
|
samer@0
|
267 return Z;
|
samer@0
|
268 }
|
samer@0
|
269 public String format(String x) {
|
samer@0
|
270 return "mogpdf("+x+")";
|
samer@0
|
271 }
|
samer@0
|
272
|
samer@0
|
273 public void setupObservers(Viewable v) {
|
samer@0
|
274 w.addObserver(this);
|
samer@0
|
275 mu.addObserver(this);
|
samer@0
|
276 sig.addObserver(this);
|
samer@0
|
277 vbl=v;
|
samer@0
|
278 }
|
samer@0
|
279
|
samer@0
|
280 public void update(Observable o, Object arg) {
|
samer@0
|
281 if (arg!=Viewable.DISPOSING) vbl.changed();
|
samer@0
|
282 }
|
samer@0
|
283 }
|
samer@0
|
284 }
|