samer@0: /* samer@0: * Copyright (c) 2002, Samer Abdallah, King's College London. samer@0: * All rights reserved. samer@0: * samer@0: * This software is provided AS iS and WITHOUT ANY WARRANTY; samer@0: * without even the implied warranty of MERCHANTABILITY or samer@0: * FITNESS FOR A PARTICULAR PURPOSE. samer@0: */ samer@0: samer@0: package samer.models; samer@0: import samer.core.*; samer@0: import samer.core.Agent.*; samer@0: import samer.core.types.*; samer@0: import samer.tools.*; samer@0: import samer.maths.*; samer@0: import samer.maths.random.*; samer@0: samer@0: samer@0: samer@0: public abstract class PCA extends NamedTask implements Model samer@0: { samer@0: int n; samer@0: Vec x; samer@0: VVector s; samer@0: Matrix W; samer@0: VDouble logA; samer@0: double[] _g,_s; // cached dE/dx samer@0: Model sourceModel; samer@0: samer@0: MatrixTimesVector infer; samer@0: MatrixTransposeTimesVector grad; samer@0: samer@0: public PCA(Vec input) { this(input.size()); setInput(input); } samer@0: public PCA(int N) samer@0: { samer@0: super("pca"); samer@0: Shell.push(node); samer@0: samer@0: n = N; samer@0: samer@0: x = null; samer@0: W = new Matrix("W",n,n); samer@0: s = new VVector("s",n); samer@0: logA = new VDouble( "log |A|",0.0,VDouble.SIGNAL); samer@0: infer = null; samer@0: Shell.pop(); samer@0: samer@0: grad = new MatrixTransposeTimesVector(W); samer@0: _g = new double[n]; samer@0: _s = s.array(); samer@0: } samer@0: samer@0: public int getSize() { return n; } samer@0: public VVector output() { return s; } samer@0: public void setInput(Vec input) { samer@0: if (input.size()!=n) throw new Error("Input vector is the wrong size"); samer@0: x = input; samer@0: infer = new MatrixTimesVector(s,W,x); samer@0: } samer@0: samer@0: public void dispose() samer@0: { samer@0: W.dispose(); samer@0: logA.dispose(); samer@0: s.dispose(); samer@0: infer.dispose(); samer@0: grad.dispose(); samer@0: samer@0: super.dispose(); samer@0: } samer@0: samer@0: public void infer() { infer.run(); s.changed(); } // compute s=Wx samer@0: public void compute() { samer@0: grad.apply(_g,_s); samer@0: } samer@0: samer@0: public double getEnergy() { return logA.value + sourceModel.getEnergy(); } samer@0: public double[] getGradient() { return _g; } samer@0: samer@0: public void starting() {} samer@0: public void stopping() {} samer@0: public void run() { infer(); } samer@0: samer@0: public Task getTrainingTask() { return new Trainer(); } samer@0: samer@0: public class Trainer extends NamedTask samer@0: { samer@0: Matrix G; samer@0: double[][] _G, _W; samer@0: double[] _g,_s, buf; samer@0: VDouble rate; samer@0: int batch, _n; samer@0: samer@0: public Trainer() samer@0: { samer@0: super("learn",PCA.this.getNode()); samer@0: Shell.push(Trainer.this.node); samer@0: _n=n; samer@0: G=new Matrix("G",n,n); samer@0: rate=new VDouble("rate",0.01); samer@0: Shell.pop(); samer@0: batch=0; samer@0: samer@0: _s=s.array(); samer@0: _G=G.getArray(); samer@0: _W=W.getArray(); samer@0: buf=new double[n]; // general purpose n-buffer samer@0: } samer@0: samer@0: public void dispose() { G.dispose(); rate.dispose(); super.dispose(); } samer@0: public void starting() { G.zero(); G.changed(); batch=0; } samer@0: public void run() samer@0: { samer@0: double q, p[]; samer@0: batch++; samer@0: samer@0: double[] phi=sourceModel.getGradient(); samer@0: for (int i=0; i<_n; i++) { samer@0: p = _G[i]; q=phi[i]; samer@0: for (int j=0; j<_n; j++) p[j] += q*_s[j]; samer@0: p[i] -= 1; samer@0: } samer@0: } samer@0: samer@0: public final void updateG() { G.changed(); } samer@0: public void flush() samer@0: { samer@0: double eta=-rate.value/batch; samer@0: samer@0: // this is going to do a matrix G *= W, in place samer@0: for (int i=0; i<_n; i++) { samer@0: for (int j=0; j<_n; j++) { samer@0: double a=0; samer@0: for (int k=0; k<_n; k++) a += _G[i][k]*_W[k][j]; samer@0: buf[j] = a; samer@0: } samer@0: Mathx.copy(buf,_G[i]); samer@0: } samer@0: samer@0: // now W += eta*G samer@0: for (int i=0; i<_n; i++) { samer@0: double [] p = _W[i], q = _G[i]; samer@0: for (int j=0; j