annotate src/samer/models/ICA.java @ 3:15b93db27c04

Get StreamSource to compile, update args for demo
author samer
date Fri, 05 Apr 2019 17:00:18 +0100
parents bf79fb79ee13
children
/*
 *	Copyright (c) 2002, Samer Abdallah, King's College London.
 *	All rights reserved.
 *
 *	This software is provided AS IS and WITHOUT ANY WARRANTY;
 *	without even the implied warranty of MERCHANTABILITY or
 *	FITNESS FOR A PARTICULAR PURPOSE.
 */

package samer.models;

import samer.core.*;
import samer.core.Agent.*;
import samer.core.types.*;
import samer.tools.*;
import samer.maths.*;
import samer.maths.opt.*;

public class ICA extends Viewable implements SafeTask, Model, Agent
{
	Model sourceModel;    // prior model over the sources s
	int n;                // dimensionality
	Vec x;                // input vector
	Matrix W, A;          // unmixing matrix W and basis A = inv(W)
	VVector s;            // state: inferred sources
	VDouble logA;         // log |det A|
	double[] _g;          // cached dE/dx

	MatrixTimesVector infer;           // computes s = W*x
	MatrixTransposeTimesVector grad;   // computes dE/dx = W'*dE/ds

	public ICA(Vec input) { this(input.size()); setInput(input); }
	public ICA(int N) { this(new Node("ica"), N); }
	public ICA(Node node, int N)
	{
		super(node);
		Shell.push(node);

		n = N;

		x = null;
		W = new Matrix("W", n, n);
		A = new Matrix("A", n, n);   // should we defer creation of this?
		s = new VVector("s", n);
		logA = new VDouble("log |A|");
		infer = null;
		Shell.pop();

		W.identity(); W.changed();
		grad = new MatrixTransposeTimesVector(W);
		_g = new double[n];

		setAgent(this);
	}

	public int getSize() { return n; }
	public VVector output() { return s; }
	public Model getOutputModel() { return sourceModel; }
	public void setOutputModel(Model m) { sourceModel = m; }
	public void setInput(Vec input) {
		if (input.size() != n) throw new Error("Input vector is the wrong size");
		x = input; infer = new MatrixTimesVector(s, W, x);
	}

	public Matrix getWeightMatrix() { return W; }
	public Matrix getBasisMatrix() { return A; }
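
	/* A rough usage sketch. Here "prior" stands for some hypothetical Model
	   implementation over the sources, and the exact wiring is an assumption,
	   not taken from this file:

	       ICA ica = new ICA(x);               // x: a Vec of length n
	       ica.setOutputModel(prior);          // prior models s = W*x
	       Model.Trainer t = ica.getTrainer();
	       for (int i = 0; i < T; i++) {       // T: number of samples
	           ica.infer();                    // s = W*x
	           prior.compute();                // refresh prior's energy/gradient
	           t.accumulate();                 // gather batch statistics
	       }
	       t.flush();                          // apply the batched update to W
	*/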

	public void dispose()
	{
		W.dispose(); A.dispose();
		logA.dispose();
		s.dispose();
		infer.dispose();
		grad.dispose();

		super.dispose();
	}

	public void infer() {   // compute s = W*x
		infer.run();
		s.changed();
	}

	public void compute() {   // cache dE/dx = W'*dE/ds
		grad.apply(sourceModel.getGradient(), _g);
	}

	public double getEnergy() { return sourceModel.getEnergy() + logA.value; }
	public double[] getGradient() { return _g; }
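
	/* Why the energy is E_s(s) + log|det A|: under the model x = A*s with
	   W = inv(A) and s = W*x, we have p_x(x) = p_s(W*x)*|det W|, so
	   -log p_x(x) = -log p_s(s) - log|det W| = E_s(s) + log|det A|.
	   By the chain rule dE/dx = W'*dE/ds, which is what compute() caches
	   in _g via grad. */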

	public Functionx functionx() {
		return new Functionx() {
			Functionx fM = sourceModel.functionx();
			double[] s = ICA.this.s.array();
			double[] gs = new double[n];

			public void dispose() { fM.dispose(); }
			public void evaluate(Datum P) { P.f = evaluate(P.x, P.g); }
			public double evaluate(double[] x, double[] g) {
				infer.apply(x, s);               // s = W*x
				double E = fM.evaluate(s, gs);   // source energy and gradient
				grad.apply(gs, g);               // g = W'*gs
				return E + logA.value;
			}
		};
	}

	public void starting() {}
	public void stopping() {}
	public void run() { infer(); }

	public void getCommands(Registry r) { r.add("basis").add("logdet"); }
	public void execute(String cmd, Environment env)
	{
		if (cmd.equals("basis")) {
			Shell.print("computing ICA basis...");
			A.assign(W.inverse());
			A.changed();
			Shell.print("...done.");
		} else if (cmd.equals("logdet")) {
			logA.set(-W.logdet());   // log|det A| = -log|det W| since A = inv(W)
		}
	}

	public Trainer getTrainer() { return new ON2Trainer(); }
	public Trainer getAltTrainer() { return new ON3Trainer(); }
	public Trainer getDecayWhenActiveTrainer() { return new ON2DecayWhenActive(); }

	/** This trainer uses an O(N^2) run step and an O(N^2) flush.
	    (A note on the update it implements follows the class.) */
	public class ON2Trainer implements Model.Trainer
	{
		double[][] _W, _GW;
		double[] _s, buf;
		VDouble rate;
		Matrix GW;
		int _n;
		double thresh, batch;
		MatrixTransposeTimesVector sW;

		public ON2Trainer()
		{
			_n = n;
			GW = new Matrix("GW", n, n);
			rate = new VDouble("rate", 0.01);
			thresh = Shell.getDouble("anomaly", n*20);
			batch = 0;

			_s = s.array();
			_GW = GW.getArray();
			_W = W.getArray();
			buf = new double[n];   // general purpose n-buffer
			sW = new MatrixTransposeTimesVector(buf, W, _s);
		}

		public String toString() { return "ON2Trainer:" + ICA.this; }
		public void dispose() { GW.dispose(); rate.dispose(); }
		public void reset() { GW.zero(); GW.changed(); batch = 0; }
		public void oneshot() { accumulate(1); flush(); }
		public void accumulate() { accumulate(1); }
		public void accumulate(double w)
		{
			// HACK: skip anomalous inputs with very high energy
			if (sourceModel.getEnergy() > thresh) return;

			double q, p[];
			batch += w;

			sW.run();   // buf = W'*s
			double[] phi = sourceModel.getGradient();
			for (int i = 0; i < _n; i++) {
				p = _GW[i]; q = w*phi[i];
				for (int j = 0; j < _n; j++) p[j] += q*buf[j];
			}
		}

		public void flush()
		{
			if (batch == 0) return;
			double eta = -rate.value/batch;

			// W += eta*(GW - batch*W): the decay term -batch*W supplies the
			// identity part of the natural-gradient update (I - phi(s)*s')*W
			for (int i = 0; i < _n; i++) {
				double[] p = _W[i], q = _GW[i];
				for (int j = 0; j < _n; j++) p[j] += eta*(q[j] - batch*p[j]);
			}
			GW.changed();

			// reset for next batch
			W.changed(); GW.zero();
			batch = 0;
		}
	}
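
	/* For reference, the update ON2Trainer implements: each accumulate()
	   adds w*phi(s)*(W'*s)' = w*(phi(s)*s')*W to GW, and flush() applies

	       W += -(rate/batch)*(GW - batch*W) = rate*(I - <phi(s)*s'>)*W

	   i.e. the natural-gradient ICA update, done in O(N^2) per sample
	   because the trailing product with W is folded into accumulation. */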

	public class ON2DecayWhenActive extends ON2Trainer {
		VDouble th = new VDouble("thresh", 0.0);   // activity threshold on |s_j|

		public String toString() { return "DecayWhenActiveTrainer:" + ICA.this; }
		public void accumulate(double w) {
			// anomaly threshold inherited from ON2Trainer
			if (sourceModel.getEnergy() > thresh) return;

			double q, p[];
			batch += w;

			sW.run();   // buf = W'*s
			double[] phi = sourceModel.getGradient();
			for (int i = 0; i < _n; i++) {
				p = _GW[i]; q = w*phi[i];
				for (int j = 0; j < _n; j++) p[j] += q*buf[j];
			}

			// decay-when-active part: apply the decay term only to
			// columns whose source is currently active
			double active = th.value;
			for (int j = 0; j < _n; j++)
				if (isActive(_s[j], active))
					for (int i = 0; i < _n; i++)
						_GW[i][j] -= _W[i][j];
		}

		public void flush()
		{
			if (batch == 0) return;
			double eta = -rate.value/batch;

			// W += eta*GW (the decay term was folded in during accumulate)
			for (int i = 0; i < _n; i++) {
				double[] p = _W[i], q = _GW[i];
				for (int j = 0; j < _n; j++) p[j] += eta*q[j];
			}
			GW.changed();
			W.changed(); GW.zero();
			batch = 0;
		}
	}
	static boolean isActive(double s, double t) { return s >= t || s <= -t; }
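
	/* The decay-when-active rule gates the weight-decay (identity) term of
	   the update per column: on each sample, column j of W accrues its
	   -W[i][j] decay only when |s_j| >= thresh, while the data term
	   phi(s)*(W'*s)' is always accumulated. With thresh = 0 every unit
	   counts as active and the rule reduces to the plain ON2Trainer update. */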

	/** This trainer saves an O(N^2) step during accumulation, at the
	    expense of an O(N^3) flush. As long as the batch size is O(N),
	    the overall cost should be about the same. The advantage is that
	    the collected statistics are more transparent, and can be used
	    to make scalar or diagonal updates more frequently.
	    (A note on the update it implements follows the class.) */
	public class ON3Trainer extends AnonymousTask implements Model.Trainer
	{
		Matrix G;
		double[][] _G, _W;
		double[] _s, buf;
		VDouble rate;
		int _n;
		double batch, thresh;

		public ON3Trainer()
		{
			_n = n;
			G = new Matrix("G", n, n);
			rate = new VDouble("rate", 0.01);
			thresh = Shell.getDouble("anomaly", 20*n);
			batch = 0;

			_s = s.array();
			_G = G.getArray();
			_W = W.getArray();
			buf = new double[n];   // general purpose n-buffer
		}

		public String toString() { return "ON3Trainer:" + ICA.this; }

		/** this is so you can manipulate the matrix before flushing */
		public Matrix getGMatrix() { return G; }

		public void starting() { reset(); }
		public void run() { accumulate(); }

		public void dispose() { G.dispose(); rate.dispose(); super.dispose(); }
		public void oneshot() { accumulate(1); flush(); }
		public void reset() { G.zero(); batch = 0; }
		public void accumulate() { accumulate(1); }
		public void accumulate(double w)
		{
			// HACK: skip anomalous inputs with very high energy
			if (sourceModel.getEnergy() > thresh) return;

			double q, p[];
			batch += w;

			// G += w*(phi(s)*s' - I)
			double[] phi = sourceModel.getGradient();
			for (int i = 0; i < _n; i++) {
				p = _G[i]; q = w*phi[i];
				for (int j = 0; j < _n; j++) p[j] += q*_s[j];
				p[i] -= w;
			}
		}

		public void flush()
		{
			if (batch == 0) return;
			double eta = -rate.value/batch;

			G.changed();

			// in-place matrix product: G = G*W
			for (int i = 0; i < _n; i++) {
				for (int j = 0; j < _n; j++) {
					double a = 0;
					for (int k = 0; k < _n; k++) a += _G[i][k]*_W[k][j];
					buf[j] = a;
				}
				Mathx.copy(buf, _G[i]);
			}

			// now W += eta*G
			for (int i = 0; i < _n; i++) {
				double[] p = _W[i], q = _G[i];
				for (int j = 0; j < _n; j++) p[j] += eta*q[j];
			}

			reset();   // ready for next batch
		}
	}
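
	/* ON3Trainer defers the product with W: it accumulates the raw
	   statistic G = sum over the batch of w*(phi(s)*s' - I), exposes it
	   via getGMatrix(), and only at flush computes G := G*W and applies

	       W += -(rate/batch)*G*W = rate*(I - <phi(s)*s'>)*W

	   which is the same natural-gradient step as ON2Trainer. */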

	// See Hyvarinen's paper that Nick gave me. Not finished.
	public class NewtonTrainer extends ON3Trainer
	{
		Function dgamma;
		VVector f;      // second derivatives of log prior
		double[] _f;

		public NewtonTrainer(Function dg)
		{
			dgamma = dg;
			f = new VVector("f", n);
			_f = f.array();
		}

		public void dispose() { f.dispose(); dgamma.dispose(); super.dispose(); }
		public void reset() { Mathx.zero(_f); super.reset(); }
		public void accumulate(double w)
		{
			// HACK: skip anomalous inputs with very high energy
			if (sourceModel.getEnergy() > thresh) return;

			dgamma.apply(_s, buf);   // buf = dgamma(s)
			Mathx.add(_f, buf);      // accumulate second-derivative statistics
			super.accumulate(w);
		}

		public void flush()
		{
			if (batch == 0) return;

			// first do some things to G, then do normal flush

			super.flush();
		}
	}
}