view src/samer/models/MatrixTrainer.java @ 5:b67a33c44de7

Remove some crap, etc
author samer
date Fri, 05 Apr 2019 21:34:25 +0100
parents bf79fb79ee13
children
line wrap: on
line source
/*
 *	Copyright (c) 2000, Samer Abdallah, King's College London.
 *	All rights reserved.
 *
 *	This software is provided AS IS and WITHOUT ANY WARRANTY; 
 *	without even the implied warranty of MERCHANTABILITY or 
 *	FITNESS FOR A PARTICULAR PURPOSE.
 */

package samer.models;
import  samer.core.*;
import  samer.core.types.*;
import  samer.maths.*;
// import  samer.tools.*;

/**
	Collects weighted outer-product (rank-one) updates for a matrix
	and applies them to it in batches.
	<p>
	Each {@link #accumulate(double)} adds {@code w * a[i] * b[j]} to
	entry (i,j) of an internal delta matrix and adds {@code w} to a
	running total weight. {@link #flush()} adds the deltas to the
	target matrix, scaled by the learning rate divided by that total
	weight, and then clears the deltas.
 */

public class MatrixTrainer implements Model.Trainer, DoubleModel
{
	/** Parameter object named "rate" bound to this trainer (via get/set). */
	protected VParameter	ratep;
	/** Target matrix that flush() updates. */
	protected Matrix		A;
	/** Matrix of accumulated, unscaled deltas (same size as A). */
	protected Matrix		T;
	/** Row arrays of A and T, modified in place. */
	protected double[][]	_A, _T;
	/** Learning rate; starts at 1. */
	protected double		rate=1;
	/** Row count (n) and column count (m) of A. */
	protected int			n, m;
	/** Left (length n) and right (length m) factor vectors for each update. */
	protected double []  a, b;
	/** Sum of the weights accumulated since the last flush or reset. */
	protected double	count=0;

	/**
		Builds a trainer from the arrays returned by {@code left.array()}
		and {@code right.array()}. Expects left.size()==n and
		right.size()==m; this is not checked.
	 */
	public MatrixTrainer(Vec left, Matrix A, Vec right) {
		this(left.array(), A, right.array());
	}

	/**
		Builds a trainer for matrix {@code A}, using {@code a} and
		{@code b} as the left and right factors of every update.
		Allocates an n-by-m delta matrix named "deltas".
	 */
	public MatrixTrainer(double [] a, Matrix A, double [] b)
	{
		// dimensions and row storage of the target
		this.A  = A;
		this.n  = A.getRowDimension();
		this.m  = A.getColumnDimension();
		this._A = A.getArray();

		// rate parameter first, then the delta buffer (creation order kept)
		this.ratep = new VParameter("rate", this);
		this.T     = new Matrix("deltas", n, m);
		this._T    = T.getArray();

		this.a = a;
		this.b = b;
	}

	/** Sets the learning rate. */
	public void set(double r) {
		this.rate = r;
	}

	/** Returns the learning rate. */
	public double get() {
		return this.rate;
	}

	/** Returns the parameter object bound to the learning rate. */
	public VParameter getRate() {
		return this.ratep;
	}

	/**
		Adds {@code w * a[i] * b[j]} to each delta entry (i,j)
		for i in [0,n) and j in [0,m).
	 */
	protected void outerProduct(double w, double [] a, double [] b) {
		for (int row = 0; row < n; row++) {
			final double scale = w * a[row];
			final double[] deltaRow = _T[row];
			for (int col = 0; col < m; col++) {
				deltaRow[col] += scale * b[col];
			}
		}
	}

	/** Disposes of the delta matrix and then the rate parameter. */
	public void dispose() {
		T.dispose();
		ratep.dispose();
	}

	/** Clears the deltas, signals the change and zeroes the total weight. */
	public void reset() {
		T.zero();
		T.changed();
		count = 0;
	}

	/** Accumulates one unit-weight update, then flushes all pending deltas. */
	public void oneshot() {
		accumulate(1);
		flush();
	}

	/** Accumulates one update with weight 1. */
	public void accumulate() {
		accumulate(1);
	}

	/**
		Adds {@code w} to the total weight, then accumulates the
		w-weighted outer product of the current {@code a} and {@code b}.
	 */
	public void accumulate(double w) {
		count += w;
		outerProduct(w, a, b);
	}

	/**
		Applies {@code A += (rate / totalWeight) * deltas}, clears the
		deltas and resets the total weight. Does nothing when the total
		weight is zero (this also avoids dividing by zero).
	 */
	public void flush() {
		if (count == 0) {
			return;
		}
		final double step = rate / count;

		// signal T before clearing it, so the accumulated stats get displayed
		T.changed();
		for (int row = 0; row < n; row++) {
			final double[] target = _A[row];
			final double[] delta  = _T[row];
			for (int col = 0; col < m; col++) {
				target[col] += step * delta[col];
			}
			Mathx.zero(delta);
		}
		A.changed();
		count = 0;
	}
}