comparison src/samer/models/MatrixTrainer.java @ 0:bf79fb79ee13

Initial Mercurial check in.
author samer
date Tue, 17 Jan 2012 17:50:20 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:bf79fb79ee13
1 /*
2 * Copyright (c) 2000, Samer Abdallah, King's College London.
3 * All rights reserved.
4 *
5 * This software is provided AS iS and WITHOUT ANY WARRANTY;
6 * without even the implied warranty of MERCHANTABILITY or
7 * FITNESS FOR A PARTICULAR PURPOSE.
8 */
9
10 package samer.models;
11 import samer.core.*;
12 import samer.core.types.*;
13 import samer.maths.*;
14 // import samer.tools.*;
15
16 /**
17 Handles batched delta updates to a matrix.
18 The default flush multiplies the accumulated
19 delta by the learing rate and adds it to the
20 matrix.
21 */
22
23 public class MatrixTrainer implements Model.Trainer, DoubleModel
24 {
25 protected VParameter ratep; // learning rate
26 protected Matrix A, T; // A=target matrix, T=matrix of deltas
27 protected double[][] _A, _T;
28 protected double rate=1;
29 protected int n, m;
30 protected double [] a, b;
31 protected double count=0;
32
33 public MatrixTrainer(Vec left, Matrix A, Vec right) {
34 // must have left.size()=n, right.size()=m
35 this(left.array(),A,right.array());
36 }
37
38 public MatrixTrainer(double [] a, Matrix A, double [] b)
39 {
40 this.A = A;
41 n = A.getRowDimension();
42 m = A.getColumnDimension();
43 _A = A.getArray();
44
45 ratep = new VParameter("rate",this);
46 T = new Matrix("deltas",n,m);
47 _T = T.getArray();
48 this.a=a; this.b=b;
49 }
50
51 public void set(double r) { rate=r; }
52 public double get() { return rate; }
53 public VParameter getRate() { return ratep; }
54
55 protected void outerProduct(double w, double [] a, double [] b) {
56 for (int i=0; i<n; i++) {
57 double q=w*a[i], p[] = _T[i];
58 for (int j=0; j<m; j++) p[j] += q*b[j];
59 }
60 }
61
62 public void dispose() { T.dispose(); ratep.dispose(); }
63 public void reset() { T.zero(); T.changed(); count=0; }
64 public void oneshot() { accumulate(1); flush(); }
65 public void accumulate() { accumulate(1); }
66 public void accumulate(double w) { count+=w; outerProduct(w,a,b); }
67
68 public void flush() {
69 if (count==0) return;
70 double eta=rate/count;
71
72 T.changed(); // to display accumulated stats
73 for (int i=0; i<n; i++) {
74 double [] p = _A[i], q = _T[i];
75 for (int j=0; j<m; j++) p[j] += eta*q[j];
76 Mathx.zero(q);
77 }
78 A.changed();
79 count=0;
80 }
81 }