annotate src/samer/models/notyet/PCA.java @ 8:5e3cbbf173aa tip

Reorganise some more
author samer
date Fri, 05 Apr 2019 22:41:58 +0100
parents bf79fb79ee13
children
rev   line source
samer@0 1 /*
samer@0 2 * Copyright (c) 2002, Samer Abdallah, King's College London.
samer@0 3 * All rights reserved.
samer@0 4 *
samer@0 5 * This software is provided AS iS and WITHOUT ANY WARRANTY;
samer@0 6 * without even the implied warranty of MERCHANTABILITY or
samer@0 7 * FITNESS FOR A PARTICULAR PURPOSE.
samer@0 8 */
samer@0 9
samer@0 10 package samer.models;
samer@0 11 import samer.core.*;
samer@0 12 import samer.core.Agent.*;
samer@0 13 import samer.core.types.*;
samer@0 14 import samer.tools.*;
samer@0 15 import samer.maths.*;
samer@0 16 import samer.maths.random.*;
samer@0 17
samer@0 18
samer@0 19
samer@0 20 public abstract class PCA extends NamedTask implements Model
samer@0 21 {
samer@0 22 int n;
samer@0 23 Vec x;
samer@0 24 VVector s;
samer@0 25 Matrix W;
samer@0 26 VDouble logA;
samer@0 27 double[] _g,_s; // cached dE/dx
samer@0 28 Model sourceModel;
samer@0 29
samer@0 30 MatrixTimesVector infer;
samer@0 31 MatrixTransposeTimesVector grad;
samer@0 32
samer@0 33 public PCA(Vec input) { this(input.size()); setInput(input); }
samer@0 34 public PCA(int N)
samer@0 35 {
samer@0 36 super("pca");
samer@0 37 Shell.push(node);
samer@0 38
samer@0 39 n = N;
samer@0 40
samer@0 41 x = null;
samer@0 42 W = new Matrix("W",n,n);
samer@0 43 s = new VVector("s",n);
samer@0 44 logA = new VDouble( "log |A|",0.0,VDouble.SIGNAL);
samer@0 45 infer = null;
samer@0 46 Shell.pop();
samer@0 47
samer@0 48 grad = new MatrixTransposeTimesVector(W);
samer@0 49 _g = new double[n];
samer@0 50 _s = s.array();
samer@0 51 }
samer@0 52
samer@0 53 public int getSize() { return n; }
samer@0 54 public VVector output() { return s; }
samer@0 55 public void setInput(Vec input) {
samer@0 56 if (input.size()!=n) throw new Error("Input vector is the wrong size");
samer@0 57 x = input;
samer@0 58 infer = new MatrixTimesVector(s,W,x);
samer@0 59 }
samer@0 60
samer@0 61 public void dispose()
samer@0 62 {
samer@0 63 W.dispose();
samer@0 64 logA.dispose();
samer@0 65 s.dispose();
samer@0 66 infer.dispose();
samer@0 67 grad.dispose();
samer@0 68
samer@0 69 super.dispose();
samer@0 70 }
samer@0 71
samer@0 72 public void infer() { infer.run(); s.changed(); } // compute s=Wx
samer@0 73 public void compute() {
samer@0 74 grad.apply(_g,_s);
samer@0 75 }
samer@0 76
samer@0 77 public double getEnergy() { return logA.value + sourceModel.getEnergy(); }
samer@0 78 public double[] getGradient() { return _g; }
samer@0 79
samer@0 80 public void starting() {}
samer@0 81 public void stopping() {}
samer@0 82 public void run() { infer(); }
samer@0 83
samer@0 84 public Task getTrainingTask() { return new Trainer(); }
samer@0 85
samer@0 86 public class Trainer extends NamedTask
samer@0 87 {
samer@0 88 Matrix G;
samer@0 89 double[][] _G, _W;
samer@0 90 double[] _g,_s, buf;
samer@0 91 VDouble rate;
samer@0 92 int batch, _n;
samer@0 93
samer@0 94 public Trainer()
samer@0 95 {
samer@0 96 super("learn",PCA.this.getNode());
samer@0 97 Shell.push(Trainer.this.node);
samer@0 98 _n=n;
samer@0 99 G=new Matrix("G",n,n);
samer@0 100 rate=new VDouble("rate",0.01);
samer@0 101 Shell.pop();
samer@0 102 batch=0;
samer@0 103
samer@0 104 _s=s.array();
samer@0 105 _G=G.getArray();
samer@0 106 _W=W.getArray();
samer@0 107 buf=new double[n]; // general purpose n-buffer
samer@0 108 }
samer@0 109
samer@0 110 public void dispose() { G.dispose(); rate.dispose(); super.dispose(); }
samer@0 111 public void starting() { G.zero(); G.changed(); batch=0; }
samer@0 112 public void run()
samer@0 113 {
samer@0 114 double q, p[];
samer@0 115 batch++;
samer@0 116
samer@0 117 double[] phi=sourceModel.getGradient();
samer@0 118 for (int i=0; i<_n; i++) {
samer@0 119 p = _G[i]; q=phi[i];
samer@0 120 for (int j=0; j<_n; j++) p[j] += q*_s[j];
samer@0 121 p[i] -= 1;
samer@0 122 }
samer@0 123 }
samer@0 124
samer@0 125 public final void updateG() { G.changed(); }
samer@0 126 public void flush()
samer@0 127 {
samer@0 128 double eta=-rate.value/batch;
samer@0 129
samer@0 130 // this is going to do a matrix G *= W, in place
samer@0 131 for (int i=0; i<_n; i++) {
samer@0 132 for (int j=0; j<_n; j++) {
samer@0 133 double a=0;
samer@0 134 for (int k=0; k<_n; k++) a += _G[i][k]*_W[k][j];
samer@0 135 buf[j] = a;
samer@0 136 }
samer@0 137 Mathx.copy(buf,_G[i]);
samer@0 138 }
samer@0 139
samer@0 140 // now W += eta*G
samer@0 141 for (int i=0; i<_n; i++) {
samer@0 142 double [] p = _W[i], q = _G[i];
samer@0 143 for (int j=0; j<n; j++) p[j] += eta*q[j];
samer@0 144 }
samer@0 145
samer@0 146 // reset for next batch
samer@0 147 W.changed();
samer@0 148 G.zero();
samer@0 149 batch=0;
samer@0 150 }
samer@0 151 }
samer@0 152 }
samer@0 153