samer@0
|
1 /*
|
samer@0
|
2 * Copyright (c) 2000, Samer Abdallah, King's College London.
|
samer@0
|
3 * All rights reserved.
|
samer@0
|
4 *
|
samer@0
|
5 * This software is provided AS IS and WITHOUT ANY WARRANTY;
|
samer@0
|
6 * without even the implied warranty of MERCHANTABILITY or
|
samer@0
|
7 * FITNESS FOR A PARTICULAR PURPOSE.
|
samer@0
|
8 */
|
samer@0
|
9
|
samer@0
|
10 package samer.models;
|
samer@0
|
11 import samer.core.*;
|
samer@0
|
12 import samer.core.types.*;
|
samer@0
|
13 import samer.maths.*;
|
samer@0
|
14 // import samer.tools.*;
|
samer@0
|
15
|
samer@0
|
16 /**
|
samer@0
|
17 Handles batched delta updates to a matrix.
|
samer@0
|
18 The default flush multiplies the accumulated
|
samer@0
|
19 delta by the learing rate and adds it to the
|
samer@0
|
20 matrix.
|
samer@0
|
21 */
|
samer@0
|
22
|
samer@0
|
23 public class MatrixTrainer implements Model.Trainer, DoubleModel
|
samer@0
|
24 {
|
samer@0
|
25 protected VParameter ratep; // learning rate
|
samer@0
|
26 protected Matrix A, T; // A=target matrix, T=matrix of deltas
|
samer@0
|
27 protected double[][] _A, _T;
|
samer@0
|
28 protected double rate=1;
|
samer@0
|
29 protected int n, m;
|
samer@0
|
30 protected double [] a, b;
|
samer@0
|
31 protected double count=0;
|
samer@0
|
32
|
samer@0
|
33 public MatrixTrainer(Vec left, Matrix A, Vec right) {
|
samer@0
|
34 // must have left.size()=n, right.size()=m
|
samer@0
|
35 this(left.array(),A,right.array());
|
samer@0
|
36 }
|
samer@0
|
37
|
samer@0
|
38 public MatrixTrainer(double [] a, Matrix A, double [] b)
|
samer@0
|
39 {
|
samer@0
|
40 this.A = A;
|
samer@0
|
41 n = A.getRowDimension();
|
samer@0
|
42 m = A.getColumnDimension();
|
samer@0
|
43 _A = A.getArray();
|
samer@0
|
44
|
samer@0
|
45 ratep = new VParameter("rate",this);
|
samer@0
|
46 T = new Matrix("deltas",n,m);
|
samer@0
|
47 _T = T.getArray();
|
samer@0
|
48 this.a=a; this.b=b;
|
samer@0
|
49 }
|
samer@0
|
50
|
samer@0
|
51 public void set(double r) { rate=r; }
|
samer@0
|
52 public double get() { return rate; }
|
samer@0
|
53 public VParameter getRate() { return ratep; }
|
samer@0
|
54
|
samer@0
|
55 protected void outerProduct(double w, double [] a, double [] b) {
|
samer@0
|
56 for (int i=0; i<n; i++) {
|
samer@0
|
57 double q=w*a[i], p[] = _T[i];
|
samer@0
|
58 for (int j=0; j<m; j++) p[j] += q*b[j];
|
samer@0
|
59 }
|
samer@0
|
60 }
|
samer@0
|
61
|
samer@0
|
62 public void dispose() { T.dispose(); ratep.dispose(); }
|
samer@0
|
63 public void reset() { T.zero(); T.changed(); count=0; }
|
samer@0
|
64 public void oneshot() { accumulate(1); flush(); }
|
samer@0
|
65 public void accumulate() { accumulate(1); }
|
samer@0
|
66 public void accumulate(double w) { count+=w; outerProduct(w,a,b); }
|
samer@0
|
67
|
samer@0
|
68 public void flush() {
|
samer@0
|
69 if (count==0) return;
|
samer@0
|
70 double eta=rate/count;
|
samer@0
|
71
|
samer@0
|
72 T.changed(); // to display accumulated stats
|
samer@0
|
73 for (int i=0; i<n; i++) {
|
samer@0
|
74 double [] p = _A[i], q = _T[i];
|
samer@0
|
75 for (int j=0; j<m; j++) p[j] += eta*q[j];
|
samer@0
|
76 Mathx.zero(q);
|
samer@0
|
77 }
|
samer@0
|
78 A.changed();
|
samer@0
|
79 count=0;
|
samer@0
|
80 }
|
samer@0
|
81 }
|