diff src/samer/models/MatrixTrainer.java @ 0:bf79fb79ee13

Initial Mercurial check in.
author samer
date Tue, 17 Jan 2012 17:50:20 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/samer/models/MatrixTrainer.java	Tue Jan 17 17:50:20 2012 +0000
@@ -0,0 +1,81 @@
+/*
+ *	Copyright (c) 2000, Samer Abdallah, King's College London.
+ *	All rights reserved.
+ *
+ *	This software is provided AS iS and WITHOUT ANY WARRANTY; 
+ *	without even the implied warranty of MERCHANTABILITY or 
+ *	FITNESS FOR A PARTICULAR PURPOSE.
+ */
+
+package samer.models;
+import  samer.core.*;
+import  samer.core.types.*;
+import  samer.maths.*;
+// import  samer.tools.*;
+
+/**
+	Handles batched delta updates to a matrix.
+	The default flush multiplies the accumulated
+	delta by the learing rate and adds it to the
+	matrix.
+ */
+
+public class MatrixTrainer implements Model.Trainer, DoubleModel
+{
+	protected VParameter	ratep;		// learning rate
+	protected Matrix		A, T;			// A=target matrix, T=matrix of deltas
+	protected double[][]	_A, _T;
+	protected double		rate=1;
+	protected int			n, m;
+	protected double []  a, b;
+	protected double	count=0;
+
+	public MatrixTrainer(Vec left, Matrix A, Vec right) {
+		// must have left.size()=n, right.size()=m
+		this(left.array(),A,right.array());
+	}
+
+	public MatrixTrainer(double [] a, Matrix A, double [] b)
+	{
+		this.A = A;
+		n = A.getRowDimension();
+		m = A.getColumnDimension();
+		_A = A.getArray();
+
+		ratep = new VParameter("rate",this);
+		T = new Matrix("deltas",n,m);
+		_T = T.getArray();
+		this.a=a; this.b=b;
+	}
+
+	public void set(double r) { rate=r; }
+	public double get() { return rate; }
+	public VParameter getRate() { return ratep; }
+
+	protected void outerProduct(double w, double [] a, double [] b) {
+		for (int i=0; i<n; i++) {
+			double q=w*a[i], p[] = _T[i]; 
+			for (int j=0; j<m; j++) p[j] += q*b[j];
+		}
+	}
+	
+	public void dispose() { T.dispose(); ratep.dispose(); }
+	public void reset() { T.zero(); T.changed(); count=0; }
+	public void oneshot() { accumulate(1); flush(); }
+	public void accumulate() { accumulate(1); } 
+	public void accumulate(double w) { count+=w; outerProduct(w,a,b); }
+	
+	public void flush() {
+		if (count==0) return;
+		double eta=rate/count;
+
+		T.changed(); // to display accumulated stats
+		for (int i=0; i<n; i++) {
+			double [] p = _A[i], q = _T[i];
+			for (int j=0; j<m; j++) p[j] += eta*q[j];
+			Mathx.zero(q);
+		}
+		A.changed();
+		count=0;
+	}
+}