annotate src/samer/models/MatrixTrainer.java @ 3:15b93db27c04

Get StreamSource to compile, update args for demo
author samer
date Fri, 05 Apr 2019 17:00:18 +0100
parents bf79fb79ee13
children
rev   line source
samer@0 1 /*
samer@0 2 * Copyright (c) 2000, Samer Abdallah, King's College London.
samer@0 3 * All rights reserved.
samer@0 4 *
samer@0 5 * This software is provided AS IS and WITHOUT ANY WARRANTY;
samer@0 6 * without even the implied warranty of MERCHANTABILITY or
samer@0 7 * FITNESS FOR A PARTICULAR PURPOSE.
samer@0 8 */
samer@0 9
samer@0 10 package samer.models;
samer@0 11 import samer.core.*;
samer@0 12 import samer.core.types.*;
samer@0 13 import samer.maths.*;
samer@0 14 // import samer.tools.*;
samer@0 15
samer@0 16 /**
samer@0 17 Handles batched delta updates to a matrix.
samer@0 18 The default flush multiplies the accumulated
samer@0 19 delta by the learing rate and adds it to the
samer@0 20 matrix.
samer@0 21 */
samer@0 22
samer@0 23 public class MatrixTrainer implements Model.Trainer, DoubleModel
samer@0 24 {
samer@0 25 protected VParameter ratep; // learning rate
samer@0 26 protected Matrix A, T; // A=target matrix, T=matrix of deltas
samer@0 27 protected double[][] _A, _T;
samer@0 28 protected double rate=1;
samer@0 29 protected int n, m;
samer@0 30 protected double [] a, b;
samer@0 31 protected double count=0;
samer@0 32
samer@0 33 public MatrixTrainer(Vec left, Matrix A, Vec right) {
samer@0 34 // must have left.size()=n, right.size()=m
samer@0 35 this(left.array(),A,right.array());
samer@0 36 }
samer@0 37
samer@0 38 public MatrixTrainer(double [] a, Matrix A, double [] b)
samer@0 39 {
samer@0 40 this.A = A;
samer@0 41 n = A.getRowDimension();
samer@0 42 m = A.getColumnDimension();
samer@0 43 _A = A.getArray();
samer@0 44
samer@0 45 ratep = new VParameter("rate",this);
samer@0 46 T = new Matrix("deltas",n,m);
samer@0 47 _T = T.getArray();
samer@0 48 this.a=a; this.b=b;
samer@0 49 }
samer@0 50
samer@0 51 public void set(double r) { rate=r; }
samer@0 52 public double get() { return rate; }
samer@0 53 public VParameter getRate() { return ratep; }
samer@0 54
samer@0 55 protected void outerProduct(double w, double [] a, double [] b) {
samer@0 56 for (int i=0; i<n; i++) {
samer@0 57 double q=w*a[i], p[] = _T[i];
samer@0 58 for (int j=0; j<m; j++) p[j] += q*b[j];
samer@0 59 }
samer@0 60 }
samer@0 61
samer@0 62 public void dispose() { T.dispose(); ratep.dispose(); }
samer@0 63 public void reset() { T.zero(); T.changed(); count=0; }
samer@0 64 public void oneshot() { accumulate(1); flush(); }
samer@0 65 public void accumulate() { accumulate(1); }
samer@0 66 public void accumulate(double w) { count+=w; outerProduct(w,a,b); }
samer@0 67
samer@0 68 public void flush() {
samer@0 69 if (count==0) return;
samer@0 70 double eta=rate/count;
samer@0 71
samer@0 72 T.changed(); // to display accumulated stats
samer@0 73 for (int i=0; i<n; i++) {
samer@0 74 double [] p = _A[i], q = _T[i];
samer@0 75 for (int j=0; j<m; j++) p[j] += eta*q[j];
samer@0 76 Mathx.zero(q);
samer@0 77 }
samer@0 78 A.changed();
samer@0 79 count=0;
samer@0 80 }
samer@0 81 }