Mercurial > hg > jslab
comparison src/samer/models/MOGVector.java @ 0:bf79fb79ee13
Initial Mercurial check in.
author | samer |
---|---|
date | Tue, 17 Jan 2012 17:50:20 +0000 |
parents | |
children | |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:bf79fb79ee13 |
---|---|
1 package samer.models; | |
2 | |
3 import samer.core.*; | |
4 import samer.core.types.*; | |
5 import samer.maths.*; | |
6 import samer.tools.*; | |
7 import java.util.*; | |
8 | |
9 public class MOGVector extends NamedTask | |
10 { | |
11 Vec input; | |
12 | |
13 int N, M; // num states, number of inputs | |
14 VVector s; // states | |
15 VVector l; // likelihoods | |
16 VDouble L; // total likelihood | |
17 | |
18 Matrix t; // deviations: t=(x-mu)/sig | |
19 Matrix post; // posterior over states | |
20 Matrix w, mu, sig; // weights, means, stddevs | |
21 Matrix dw, dmu, dsig; // accumulated stats for learning | |
22 | |
23 VDouble nuw, numu, nusig; // learning rates | |
24 | |
25 | |
26 // ------------ private bits ------------------------ | |
27 | |
28 private final static double Q=0.5*Math.log(2*Math.PI); | |
29 private double[] tmp; | |
30 private double [] _s; | |
31 private double [] _l; | |
32 private double [][] _t; | |
33 private double [][] _w; | |
34 private double [][] _mu; | |
35 private double [][] _sig; | |
36 private double [][] _post; | |
37 // --------------------------------------------------- | |
38 | |
39 public MOGVector(Vec in, int n) | |
40 { | |
41 super("mog"); | |
42 input=in; | |
43 | |
44 Shell.push(node); | |
45 N=n; M=in.size(); | |
46 | |
47 s = new VVector("state",M); | |
48 l = new VVector("-log p(s)",M); | |
49 L = new VDouble("likelihood"); | |
50 | |
51 t = new Matrix("t",N,M); | |
52 post = new Matrix("p(s|x)",N,M); | |
53 mu = new Matrix("means",N,M); | |
54 sig = new Matrix("sigmas",N,M); | |
55 w = new Matrix("weight",N,M); | |
56 dmu = new Matrix("dmu",N,M); | |
57 dsig = new Matrix("dsig",N,M); | |
58 dw = new Matrix("dw",N,M); | |
59 | |
60 tmp=new double[M]; | |
61 | |
62 Shell.pop(); | |
63 | |
64 // initialise parameters | |
65 | |
66 // weights... | |
67 w.set(new Constant(1.0/N)); | |
68 w.changed(); | |
69 | |
70 // means... | |
71 for (int i=0; i<N; i++) { | |
72 mu.setMatrix(i,i,0,M-1,new Jama.Matrix(1,M,(double)i)); | |
73 } | |
74 mu.changed(); | |
75 | |
76 // sigmas | |
77 sig.set(new Constant(1)); | |
78 sig.changed(); | |
79 | |
80 | |
81 // ----- initialise private bits ---------- | |
82 _s=s.array(); | |
83 _l=l.array(); | |
84 _t=t.getArray(); | |
85 _w=w.getArray(); | |
86 _mu=mu.getArray(); | |
87 _sig=sig.getArray(); | |
88 _post=post.getArray(); | |
89 } | |
90 | |
91 // normalise column sum of w, return array of sums | |
92 private void normalise(double [][] w, double[] sum) { | |
93 // make sure weights are normalised properly | |
94 Mathx.zero(sum); | |
95 for (int i=0; i<N; i++) Mathx.add(sum,w[i]); | |
96 for (int i=0; i<N; i++) Mathx.div(w[i],sum); | |
97 } | |
98 | |
99 public VFunction getPDF() { | |
100 PDF fn=new PDF(); | |
101 VFunction vfn=new VFunction("pdf",fn); | |
102 fn.setupObservers(vfn); | |
103 return vfn; | |
104 } | |
105 | |
106 public void run() | |
107 { | |
108 // matlab equivalent code: | |
109 // t = (x - mu)./sig | |
110 // post = (w./sig).*exp(-0.5*t.^2) | |
111 // [dummy, s] = max(post,2) | |
112 // Z = sum(post,2) | |
113 // post = post./(Z*ones(1,N)) | |
114 // Z=Z*Q; | |
115 | |
116 { | |
117 // Vec.Iterator it=input.iterator(); | |
118 | |
119 double [] x=input.array(); | |
120 | |
121 for (int i=0; i<N; i++) { | |
122 for (int j=0; j<M; j++) { | |
123 _t[i][j] = (x[j]-_mu[i][j])/_sig[i][j]; | |
124 _post[i][j] = (_w[i][j]/_sig[i][j])*Math.exp(-0.5*_t[i][j]*_t[i][j]); | |
125 } | |
126 } | |
127 | |
128 // this computes partition function | |
129 // and normalises posterior in one go | |
130 normalise(_post,_l); | |
131 | |
132 // get MAP state | |
133 for (int j=0; j<M; j++) { | |
134 int state=0; | |
135 double pmax=_post[0][j]; | |
136 | |
137 for (int i=1; i<N; i++) { | |
138 if (_post[i][j]>pmax) { state=i; pmax=_post[i][j]; } | |
139 } | |
140 _s[j]=state; | |
141 } | |
142 | |
143 for (int j=0; j<M; j++) _l[j]=-Math.log(_l[j]); | |
144 | |
145 L.value = Mathx.sum(_l); | |
146 } | |
147 | |
148 L.changed(); | |
149 t.changed(); | |
150 s.changed(); | |
151 l.changed(); | |
152 post.changed(); | |
153 } | |
154 | |
155 public Task learnTask() { | |
156 return new AnonymousTask() { | |
157 // buffer changes to parameters | |
158 double [][] _dw=dw.getArray(); | |
159 double [][] _dmu=dmu.getArray(); | |
160 double [][] _dsig=dsig.getArray(); | |
161 | |
162 public void starting() { | |
163 dw.zero(); | |
164 dmu.zero(); | |
165 dsig.zero(); | |
166 normalise(w.getArray(),tmp); | |
167 } | |
168 | |
169 public void run() { | |
170 for (int i=0; i<N; i++) { | |
171 for (int j=0; j<M; j++) { | |
172 double pp=_post[i][j], tt=_t[i][j]; | |
173 _dw[i][j] += pp; | |
174 _dmu[i][j] += pp*tt; | |
175 _dsig[i][j]+= pp*(tt*tt-1); | |
176 } | |
177 } | |
178 } | |
179 }; | |
180 } | |
181 | |
182 public Task flushTask() { | |
183 Shell.push(node); | |
184 | |
185 try { | |
186 return new AnonymousTask() { | |
187 | |
188 VDouble nuw=new VDouble("weights.learn.rate",.001); | |
189 VDouble numu=new VDouble("means.learn.rate",.001); | |
190 VDouble nusig=new VDouble("sigmas.learn.rate",.001); | |
191 | |
192 double [][] _w=w.getArray(); | |
193 double [][] _dw=dw.getArray(); | |
194 | |
195 public void run() | |
196 { | |
197 // mu += numu*dmu.*sig | |
198 // sig += nusig*dsig.*sig; | |
199 // lambda = sum(w.*dw,2)./sum(w.*w,2) | |
200 // dw -= (lambda*ones(1,N)) .* w | |
201 // w *= exp(nu*dw) | |
202 | |
203 dmu.arrayTimesEquals(sig); | |
204 dmu.timesEquals(nusig.value); | |
205 mu.plusEquals(dmu); | |
206 mu.changed(); | |
207 dmu.zero(); | |
208 | |
209 dsig.arrayTimesEquals(sig); | |
210 dsig.timesEquals(nusig.value); | |
211 sig.plusEquals(dsig); | |
212 sig.changed(); | |
213 dsig.zero(); | |
214 | |
215 { | |
216 // the effect of this is to project | |
217 // dw away from w. the resulting vector | |
218 // is then added to the log of w | |
219 | |
220 double nu=nuw.value; | |
221 | |
222 for (int j=0; j<M; j++) { | |
223 double w2=0, lambda=0; | |
224 for (int i=0; i<N; i++) { | |
225 lambda += _w[i][j]*_dw[i][j]; | |
226 w2 += _w[i][j]*_w[i][j]; | |
227 } | |
228 tmp[j]=lambda/w2; | |
229 } | |
230 | |
231 for (int i=0; i<N; i++) { | |
232 for (int j=0; j<M; j++) { | |
233 _w[i][j] *= Math.exp( | |
234 nu*(_dw[i][j] - tmp[j]*_w[i][j]) | |
235 ); // update w | |
236 } | |
237 } | |
238 } | |
239 | |
240 normalise(w.getArray(),tmp); | |
241 w.changed(); | |
242 dw.zero(); | |
243 } | |
244 }; | |
245 } finally { Shell.pop(); } | |
246 } | |
247 | |
248 class PDF extends Function implements Observer { | |
249 VInteger index; | |
250 Viewable vbl=null; | |
251 | |
252 public PDF() { | |
253 index=new VInteger("index",M/2); | |
254 index.setRange(0,M-1); | |
255 index.addObserver(this); | |
256 } | |
257 | |
258 public void dispose() { index.dispose(); } | |
259 | |
260 public double apply(double x) { | |
261 double Z=0; | |
262 for (int i=0; i<N; i++) { | |
263 int j=index.value; | |
264 double t=(x-_mu[i][j])/_sig[i][j]; | |
265 Z+=(_w[i][j]/_sig[i][j])*Math.exp(-0.5*t*t); | |
266 } | |
267 return Z; | |
268 } | |
269 public String format(String x) { | |
270 return "mogpdf("+x+")"; | |
271 } | |
272 | |
273 public void setupObservers(Viewable v) { | |
274 w.addObserver(this); | |
275 mu.addObserver(this); | |
276 sig.addObserver(this); | |
277 vbl=v; | |
278 } | |
279 | |
280 public void update(Observable o, Object arg) { | |
281 if (arg!=Viewable.DISPOSING) vbl.changed(); | |
282 } | |
283 } | |
284 } |