/*
 * Copyright (c) 2002, Samer Abdallah, King's College London.
 * All rights reserved.
 *
 * This software is provided AS IS and WITHOUT ANY WARRANTY;
 * without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.
 */

package samer.models;

import samer.core.*;
import samer.core.Agent.*;
import samer.core.types.*;
import samer.tools.*;
import samer.maths.*;
import samer.maths.opt.*;

public class ICA extends Viewable implements SafeTask, Model, Agent
{
    Model sourceModel;
    int n;
    Vec x;
    Matrix W, A;
    VVector s;      // state
    VDouble logA;
    double[] _g;    // cached dE/dx

    MatrixTimesVector infer;
    MatrixTransposeTimesVector grad;

    public ICA(Vec input) { this(input.size()); setInput(input); }
    public ICA(int N) { this(new Node("ica"), N); }
    public ICA(Node node, int N)
    {
        super(node);
        Shell.push(node);

        n = N;

        x = null;
        W = new Matrix("W", n, n);
        A = new Matrix("A", n, n);      // should we defer creation of this?
        s = new VVector("s", n);
        logA = new VDouble("log |A|");
        infer = null;
        Shell.pop();

        W.identity(); W.changed();
        grad = new MatrixTransposeTimesVector(W);
        _g = new double[n];

        setAgent(this);
    }

    public int getSize() { return n; }
    public VVector output() { return s; }
    public Model getOutputModel() { return sourceModel; }
    public void setOutputModel(Model m) { sourceModel = m; }

    public void setInput(Vec input) {
        if (input.size() != n) throw new Error("Input vector is the wrong size");
        x = input; infer = new MatrixTimesVector(s, W, x);
    }

    public Matrix getWeightMatrix() { return W; }
    public Matrix getBasisMatrix() { return A; }

    public void dispose()
    {
        W.dispose(); A.dispose();
        logA.dispose();
        s.dispose();
        infer.dispose();
        grad.dispose();

        super.dispose();
    }

    public void infer() {   // compute s = W*x
        infer.run();
        s.changed();
    }

    public void compute() { // cache _g = W'*phi(s), where phi = dE/ds from the source model
        grad.apply(sourceModel.getGradient(), _g);
    }

    public double getEnergy() { return sourceModel.getEnergy() + logA.value; }
    public double[] getGradient() { return _g; }
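
    /* The model implemented by infer(), compute() and getEnergy(): with
     * s = W*x and a source density p_s(s) = exp(-E_s(s)), the density of
     * x under the model is
     *     p(x) = p_s(W*x) * |det W|,
     * so the energy is
     *     E(x) = E_s(s) - log|det W| = E_s(s) + log|det A|,
     * and its gradient with respect to x is
     *     dE/dx = W' * phi(s),   phi = dE_s/ds,
     * which is exactly what compute() caches in _g.
     */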

    /** Wraps this model as a function of x for optimisers: evaluates
        f(x) = E_s(W*x) + log|A| and its gradient g = W'*gs. */
    public Functionx functionx() {
        return new Functionx() {
            Functionx fM = sourceModel.functionx();
            double[] s = ICA.this.s.array();    // new double[n];
            double[] gs = new double[n];

            public void dispose() { fM.dispose(); }
            public void evaluate(Datum P) { P.f = evaluate(P.x, P.g); }
            public double evaluate(double[] x, double[] g) {
                infer.apply(x, s);  double E = fM.evaluate(s, gs);
                grad.apply(gs, g);  return E + logA.value;
            }
        };
    }

    public void starting() {}
    public void stopping() {}
    public void run() { infer(); }

    public void getCommands(Registry r) { r.add("basis").add("logdet"); }
    public void execute(String cmd, Environment env)
    {
        if (cmd.equals("basis")) {
            Shell.print("computing ICA basis...");
            A.assign(W.inverse());
            A.changed();
            Shell.print("...done.");
        } else if (cmd.equals("logdet")) {
            logA.set(-W.logdet());  // logA = -log|det W| = log|det A|, used by getEnergy()
        }
    }

    public Trainer getTrainer() { return new ON2Trainer(); }
    public Trainer getAltTrainer() { return new ON3Trainer(); }
    public Trainer getDecayWhenActiveTrainer() { return new ON2DecayWhenActive(); }
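
    /* Typical wiring, as a rough sketch only: the prior model class and the
     * data-reading helpers named here are hypothetical, not part of this
     * package.
     *
     *     ICA ica = new ICA(x);                        // x: input Vec of size n
     *     Model prior = new SomePriorModel(ica.output());  // hypothetical source model
     *     ica.setOutputModel(prior);
     *     Trainer t = ica.getTrainer();                // O(N^2) natural-gradient trainer
     *     for (int k = 1; nextFrame(x); k++) {         // nextFrame() is assumed
     *         ica.infer();                             // s = W*x
     *         prior.compute();                         // prepare phi(s) = dE/ds
     *         t.accumulate();                          // gather batch statistics
     *         if (k % 100 == 0) t.flush();             // apply the batched update to W
     *     }
     */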
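
    /* The O(N^2) trainer below implements the natural-gradient ICA rule
     *     dW = -rate * ( <phi(s)*s'> - I ) * W,   phi = dE/ds.
     * Since phi*s'*W = phi*(W'*s)', accumulate() forms buf = W'*s and adds
     * w*phi*buf' to GW, at O(N^2) per sample; the -I*W term is deferred to
     * flush(), where W += eta*(GW - batch*W) applies the whole batch in a
     * single O(N^2) pass.
     */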
    /** This trainer uses an O(N^2) run step and an O(N^2) flush. */

    public class ON2Trainer implements Model.Trainer
    {
        double[][] _W, _GW;
        double[] _s, buf;
        VDouble rate;
        Matrix GW;
        int _n;
        double thresh, batch;
        MatrixTransposeTimesVector sW;

        public ON2Trainer()
        {
            _n = n;
            GW = new Matrix("GW", n, n);
            rate = new VDouble("rate", 0.01);
            thresh = Shell.getDouble("anomaly", n*20);
            batch = 0;

            _s = s.array();
            _GW = GW.getArray();
            _W = W.getArray();
            buf = new double[n];    // general-purpose n-buffer
            sW = new MatrixTransposeTimesVector(buf, W, _s);
        }

        public String toString() { return "ON2Trainer:" + ICA.this; }
        public void dispose() { GW.dispose(); rate.dispose(); }
        public void reset() { GW.zero(); GW.changed(); batch = 0; }
        public void oneshot() { accumulate(1); flush(); }
        public void accumulate() { accumulate(1); }
        public void accumulate(double w)
        {
            // HACK: skip anomalous frames whose energy exceeds the threshold
            if (sourceModel.getEnergy() > thresh) return;

            double q, p[];
            batch += w;

            sW.run();   // buf = W'*s
            double[] phi = sourceModel.getGradient();
            for (int i = 0; i < _n; i++) {
                // one-pass variant with the decay term included directly:
                //   p = _GW[i]; q = phi[i]; double[] r = _W[i];
                //   for (int j = 0; j < _n; j++) p[j] += w*(q*buf[j] - r[j]);
                p = _GW[i]; q = w*phi[i];
                for (int j = 0; j < _n; j++) p[j] += q*buf[j];
            }
        }

        public void flush()
        {
            if (batch == 0) return;
            double eta = -rate.value/batch;

            // W += eta*(GW - batch*W): the -I*W part of the natural gradient,
            // deferred from accumulate(), is applied here
            for (int i = 0; i < _n; i++) {
                double[] p = _W[i], q = _GW[i];
                // for (int j = 0; j < _n; j++) p[j] += eta*q[j];
                for (int j = 0; j < _n; j++) p[j] += eta*(q[j] - batch*p[j]);
                // Mathx.mul(q, 1.0/batch);
            }
            GW.changed();

            // reset for next batch
            W.changed(); GW.zero();
            batch = 0;
        }
    }

    /** Variant of ON2Trainer in which the weight-decay term is applied
        only to columns whose output unit is currently active. */
    public class ON2DecayWhenActive extends ON2Trainer {
        VDouble th = new VDouble("thresh", 0.0);

        public String toString() { return "DecayWhenActiveTrainer:" + ICA.this; }
        public void accumulate(double w) {
            if (sourceModel.getEnergy() > thresh) return;

            double q, p[];
            batch += w;

            sW.run();   // buf = W'*s
            double[] phi = sourceModel.getGradient();
            for (int i = 0; i < _n; i++) {
                p = _GW[i]; q = w*phi[i];
                for (int j = 0; j < _n; j++) p[j] += q*buf[j];
            }

            // decay-when-active part: subtract W columns for active units only
            double active = th.value;
            for (int j = 0; j < _n; j++)
                if (isActive(_s[j], active))
                    for (int i = 0; i < _n; i++)
                        _GW[i][j] -= _W[i][j];
        }

        public void flush()
        {
            if (batch == 0) return;
            double eta = -rate.value/batch;

            // now W += eta*GW (the decay is already in GW, so no -batch*W term)
            for (int i = 0; i < _n; i++) {
                double[] p = _W[i], q = _GW[i];
                for (int j = 0; j < _n; j++) p[j] += eta*q[j];
            }
            GW.changed();
            W.changed(); GW.zero();
            batch = 0;
        }
    }
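
    /* In equation form, ON2DecayWhenActive replaces the uniform -I term of
     * the natural gradient with a per-column mask: writing a_j(s) = 1 if
     * |s_j| >= thresh and 0 otherwise (see isActive below), the update is
     *     dW = -rate * ( <phi(s)*(W'*s)'> - W*<diag(a(s))> ),
     * so weight decay acts only on columns driving currently active outputs.
     */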
    static boolean isActive(double s, double t) { return s >= t || s <= -t; }
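
    /* In equation form, the O(N^3) trainer below accumulates the raw
     * statistic
     *     G = sum_w w*(phi(s)*s' - I),
     * and only forms the natural-gradient direction (<phi*s'> - I)*W at
     * flush time, via the in-place product G := G*W; that product is the
     * O(N^3) step mentioned in the comment below.
     */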
    /** This trainer saves an O(N^2) step during accumulation, at the
        expense of an O(N^3) flush. As long as the batch size is O(N),
        the overall cost is about the same. The advantage is that the
        collected statistics are more transparent, and can be used to
        make scalar or diagonal updates more frequently. */

    public class ON3Trainer extends AnonymousTask implements Model.Trainer
    {
        Matrix G;
        double[][] _G, _W;
        double[] _s, buf;
        VDouble rate;
        int _n;
        double batch, thresh;

        public ON3Trainer()
        {
            _n = n;
            G = new Matrix("G", n, n);
            rate = new VDouble("rate", 0.01);
            thresh = Shell.getDouble("anomaly", 20*n);
            batch = 0;

            _s = s.array();
            _G = G.getArray();
            _W = W.getArray();
            buf = new double[n];    // general-purpose n-buffer
        }

        public String toString() { return "ON3Trainer:" + ICA.this; }

        /** The accumulated statistics, exposed so the matrix can be
            manipulated before flushing. */
        public Matrix getGMatrix() { return G; }

        public void starting() { reset(); }
        public void run() { accumulate(); }

        public void dispose() { G.dispose(); rate.dispose(); super.dispose(); }
        public void oneshot() { accumulate(1); flush(); }
        public void reset() { G.zero(); batch = 0; }
        public void accumulate() { accumulate(1); }
        public void accumulate(double w)
        {
            // HACK: skip anomalous frames whose energy exceeds the threshold
            if (sourceModel.getEnergy() > thresh) return;

            double q, p[];
            batch += w;

            double[] phi = sourceModel.getGradient();
            for (int i = 0; i < _n; i++) {
                p = _G[i]; q = w*phi[i];
                // if (Double.isNaN(q)) throw new Error("NAN" + i);
                for (int j = 0; j < _n; j++) p[j] += q*_s[j];   // G += w*phi*s'
                p[i] -= w;                                      // ... minus w*I
            }
        }

        public void flush()
        {
            if (batch == 0) return;
            double eta = -rate.value/batch;

            G.changed();

            // in-place matrix product G := G*W, one row at a time
            for (int i = 0; i < _n; i++) {
                for (int j = 0; j < _n; j++) {
                    double a = 0;
                    for (int k = 0; k < _n; k++) a += _G[i][k]*_W[k][j];
                    buf[j] = a;
                }
                Mathx.copy(buf, _G[i]);
            }

            // now W += eta*G
            for (int i = 0; i < _n; i++) {
                double[] p = _W[i], q = _G[i];
                for (int j = 0; j < _n; j++) p[j] += eta*q[j];
            }

            reset();    // ready for next batch
        }
    }

    // See Hyvarinen's paper that Nick gave me. Not finished.
    public class NewtonTrainer extends ON3Trainer
    {
        Function dgamma;
        VVector f;      // second derivatives of log prior
        double[] _f;

        public NewtonTrainer(Function dg)
        {
            dgamma = dg;
            f = new VVector("f", n);
            _f = f.array();
        }

        public void dispose() { f.dispose(); dgamma.dispose(); super.dispose(); }
        public void reset() { Mathx.zero(_f); super.reset(); }
        public void accumulate(double w)
        {
            // HACK: skip anomalous frames whose energy exceeds the threshold
            if (sourceModel.getEnergy() > thresh) return;

            dgamma.apply(_s, buf);  // buf = second derivatives at s
            Mathx.add(_f, buf);
            super.accumulate(w);
        }

        public void flush()
        {
            if (batch == 0) return;

            // TODO: first do some things to G, then do the normal flush

            super.flush();
        }
    }
}