samer@0: /* samer@0: * Copyright (c) 2000, Samer Abdallah, King's College London. samer@0: * All rights reserved. samer@0: * samer@0: * This software is provided AS iS and WITHOUT ANY WARRANTY; samer@0: * without even the implied warranty of MERCHANTABILITY or samer@0: * FITNESS FOR A PARTICULAR PURPOSE. samer@0: */ samer@0: samer@0: package samer.models; samer@0: import samer.core.*; samer@0: import samer.core.types.*; samer@0: import samer.maths.*; samer@0: // import samer.tools.*; samer@0: samer@0: /** samer@0: Handles batched delta updates to a matrix. samer@0: The default flush multiplies the accumulated samer@0: delta by the learing rate and adds it to the samer@0: matrix. samer@0: */ samer@0: samer@0: public class MatrixTrainer implements Model.Trainer, DoubleModel samer@0: { samer@0: protected VParameter ratep; // learning rate samer@0: protected Matrix A, T; // A=target matrix, T=matrix of deltas samer@0: protected double[][] _A, _T; samer@0: protected double rate=1; samer@0: protected int n, m; samer@0: protected double [] a, b; samer@0: protected double count=0; samer@0: samer@0: public MatrixTrainer(Vec left, Matrix A, Vec right) { samer@0: // must have left.size()=n, right.size()=m samer@0: this(left.array(),A,right.array()); samer@0: } samer@0: samer@0: public MatrixTrainer(double [] a, Matrix A, double [] b) samer@0: { samer@0: this.A = A; samer@0: n = A.getRowDimension(); samer@0: m = A.getColumnDimension(); samer@0: _A = A.getArray(); samer@0: samer@0: ratep = new VParameter("rate",this); samer@0: T = new Matrix("deltas",n,m); samer@0: _T = T.getArray(); samer@0: this.a=a; this.b=b; samer@0: } samer@0: samer@0: public void set(double r) { rate=r; } samer@0: public double get() { return rate; } samer@0: public VParameter getRate() { return ratep; } samer@0: samer@0: protected void outerProduct(double w, double [] a, double [] b) { samer@0: for (int i=0; i