Mercurial > hg > jslab
view src/samer/models/MatrixTrainer.java @ 5:b67a33c44de7
Remove some crap, etc
author | samer |
---|---|
date | Fri, 05 Apr 2019 21:34:25 +0100 |
parents | bf79fb79ee13 |
children |
line wrap: on
line source
/* * Copyright (c) 2000, Samer Abdallah, King's College London. * All rights reserved. * * This software is provided AS iS and WITHOUT ANY WARRANTY; * without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. */ package samer.models; import samer.core.*; import samer.core.types.*; import samer.maths.*; // import samer.tools.*; /** Handles batched delta updates to a matrix. The default flush multiplies the accumulated delta by the learing rate and adds it to the matrix. */ public class MatrixTrainer implements Model.Trainer, DoubleModel { protected VParameter ratep; // learning rate protected Matrix A, T; // A=target matrix, T=matrix of deltas protected double[][] _A, _T; protected double rate=1; protected int n, m; protected double [] a, b; protected double count=0; public MatrixTrainer(Vec left, Matrix A, Vec right) { // must have left.size()=n, right.size()=m this(left.array(),A,right.array()); } public MatrixTrainer(double [] a, Matrix A, double [] b) { this.A = A; n = A.getRowDimension(); m = A.getColumnDimension(); _A = A.getArray(); ratep = new VParameter("rate",this); T = new Matrix("deltas",n,m); _T = T.getArray(); this.a=a; this.b=b; } public void set(double r) { rate=r; } public double get() { return rate; } public VParameter getRate() { return ratep; } protected void outerProduct(double w, double [] a, double [] b) { for (int i=0; i<n; i++) { double q=w*a[i], p[] = _T[i]; for (int j=0; j<m; j++) p[j] += q*b[j]; } } public void dispose() { T.dispose(); ratep.dispose(); } public void reset() { T.zero(); T.changed(); count=0; } public void oneshot() { accumulate(1); flush(); } public void accumulate() { accumulate(1); } public void accumulate(double w) { count+=w; outerProduct(w,a,b); } public void flush() { if (count==0) return; double eta=rate/count; T.changed(); // to display accumulated stats for (int i=0; i<n; i++) { double [] p = _A[i], q = _T[i]; for (int j=0; j<m; j++) p[j] += eta*q[j]; Mathx.zero(q); } A.changed(); count=0; } }