Mercurial > hg > jslab
view src/samer/models/notyet/PCA.java @ 8:5e3cbbf173aa tip
Reorganise some more
author | samer |
---|---|
date | Fri, 05 Apr 2019 22:41:58 +0100 |
parents | bf79fb79ee13 |
children |
line wrap: on
line source
/* * Copyright (c) 2002, Samer Abdallah, King's College London. * All rights reserved. * * This software is provided AS iS and WITHOUT ANY WARRANTY; * without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. */ package samer.models; import samer.core.*; import samer.core.Agent.*; import samer.core.types.*; import samer.tools.*; import samer.maths.*; import samer.maths.random.*; public abstract class PCA extends NamedTask implements Model { int n; Vec x; VVector s; Matrix W; VDouble logA; double[] _g,_s; // cached dE/dx Model sourceModel; MatrixTimesVector infer; MatrixTransposeTimesVector grad; public PCA(Vec input) { this(input.size()); setInput(input); } public PCA(int N) { super("pca"); Shell.push(node); n = N; x = null; W = new Matrix("W",n,n); s = new VVector("s",n); logA = new VDouble( "log |A|",0.0,VDouble.SIGNAL); infer = null; Shell.pop(); grad = new MatrixTransposeTimesVector(W); _g = new double[n]; _s = s.array(); } public int getSize() { return n; } public VVector output() { return s; } public void setInput(Vec input) { if (input.size()!=n) throw new Error("Input vector is the wrong size"); x = input; infer = new MatrixTimesVector(s,W,x); } public void dispose() { W.dispose(); logA.dispose(); s.dispose(); infer.dispose(); grad.dispose(); super.dispose(); } public void infer() { infer.run(); s.changed(); } // compute s=Wx public void compute() { grad.apply(_g,_s); } public double getEnergy() { return logA.value + sourceModel.getEnergy(); } public double[] getGradient() { return _g; } public void starting() {} public void stopping() {} public void run() { infer(); } public Task getTrainingTask() { return new Trainer(); } public class Trainer extends NamedTask { Matrix G; double[][] _G, _W; double[] _g,_s, buf; VDouble rate; int batch, _n; public Trainer() { super("learn",PCA.this.getNode()); Shell.push(Trainer.this.node); _n=n; G=new Matrix("G",n,n); rate=new VDouble("rate",0.01); Shell.pop(); batch=0; _s=s.array(); _G=G.getArray(); _W=W.getArray(); buf=new double[n]; // general purpose n-buffer } public void dispose() { G.dispose(); rate.dispose(); super.dispose(); } public void starting() { G.zero(); G.changed(); batch=0; } public void run() { double q, p[]; batch++; double[] phi=sourceModel.getGradient(); for (int i=0; i<_n; i++) { p = _G[i]; q=phi[i]; for (int j=0; j<_n; j++) p[j] += q*_s[j]; p[i] -= 1; } } public final void updateG() { G.changed(); } public void flush() { double eta=-rate.value/batch; // this is going to do a matrix G *= W, in place for (int i=0; i<_n; i++) { for (int j=0; j<_n; j++) { double a=0; for (int k=0; k<_n; k++) a += _G[i][k]*_W[k][j]; buf[j] = a; } Mathx.copy(buf,_G[i]); } // now W += eta*G for (int i=0; i<_n; i++) { double [] p = _W[i], q = _G[i]; for (int j=0; j<n; j++) p[j] += eta*q[j]; } // reset for next batch W.changed(); G.zero(); batch=0; } } }