Mercurial > hg > jslab
diff src/samer/models/MatrixTrainer.java @ 0:bf79fb79ee13
Initial Mercurial check in.
author | samer |
---|---|
date | Tue, 17 Jan 2012 17:50:20 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/samer/models/MatrixTrainer.java Tue Jan 17 17:50:20 2012 +0000 @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2000, Samer Abdallah, King's College London. + * All rights reserved. + * + * This software is provided AS iS and WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. + */ + +package samer.models; +import samer.core.*; +import samer.core.types.*; +import samer.maths.*; +// import samer.tools.*; + +/** + Handles batched delta updates to a matrix. + The default flush multiplies the accumulated + delta by the learing rate and adds it to the + matrix. + */ + +public class MatrixTrainer implements Model.Trainer, DoubleModel +{ + protected VParameter ratep; // learning rate + protected Matrix A, T; // A=target matrix, T=matrix of deltas + protected double[][] _A, _T; + protected double rate=1; + protected int n, m; + protected double [] a, b; + protected double count=0; + + public MatrixTrainer(Vec left, Matrix A, Vec right) { + // must have left.size()=n, right.size()=m + this(left.array(),A,right.array()); + } + + public MatrixTrainer(double [] a, Matrix A, double [] b) + { + this.A = A; + n = A.getRowDimension(); + m = A.getColumnDimension(); + _A = A.getArray(); + + ratep = new VParameter("rate",this); + T = new Matrix("deltas",n,m); + _T = T.getArray(); + this.a=a; this.b=b; + } + + public void set(double r) { rate=r; } + public double get() { return rate; } + public VParameter getRate() { return ratep; } + + protected void outerProduct(double w, double [] a, double [] b) { + for (int i=0; i<n; i++) { + double q=w*a[i], p[] = _T[i]; + for (int j=0; j<m; j++) p[j] += q*b[j]; + } + } + + public void dispose() { T.dispose(); ratep.dispose(); } + public void reset() { T.zero(); T.changed(); count=0; } + public void oneshot() { accumulate(1); flush(); } + public void accumulate() { accumulate(1); } + public void accumulate(double w) { count+=w; outerProduct(w,a,b); } + + public void flush() { + if (count==0) return; + double eta=rate/count; + + T.changed(); // to display accumulated stats + for (int i=0; i<n; i++) { + double [] p = _A[i], q = _T[i]; + for (int j=0; j<m; j++) p[j] += eta*q[j]; + Mathx.zero(q); + } + A.changed(); + count=0; + } +}