Mercurial > hg > jslab
comparison src/samer/models/MatrixTrainer.java @ 0:bf79fb79ee13
Initial Mercurial check in.
author | samer |
---|---|
date | Tue, 17 Jan 2012 17:50:20 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:bf79fb79ee13 |
---|---|
1 /* | |
2 * Copyright (c) 2000, Samer Abdallah, King's College London. | |
3 * All rights reserved. | |
4 * | |
5 * This software is provided AS IS and WITHOUT ANY WARRANTY; | |
6 * without even the implied warranty of MERCHANTABILITY or | |
7 * FITNESS FOR A PARTICULAR PURPOSE. | |
8 */ | |
9 | |
10 package samer.models; | |
11 import samer.core.*; | |
12 import samer.core.types.*; | |
13 import samer.maths.*; | |
14 // import samer.tools.*; | |
15 | |
16 /** | |
17 Handles batched delta updates to a matrix. | |
18 The default flush multiplies the accumulated | |
19 delta by the learing rate and adds it to the | |
20 matrix. | |
21 */ | |
22 | |
23 public class MatrixTrainer implements Model.Trainer, DoubleModel | |
24 { | |
25 protected VParameter ratep; // learning rate | |
26 protected Matrix A, T; // A=target matrix, T=matrix of deltas | |
27 protected double[][] _A, _T; | |
28 protected double rate=1; | |
29 protected int n, m; | |
30 protected double [] a, b; | |
31 protected double count=0; | |
32 | |
33 public MatrixTrainer(Vec left, Matrix A, Vec right) { | |
34 // must have left.size()=n, right.size()=m | |
35 this(left.array(),A,right.array()); | |
36 } | |
37 | |
38 public MatrixTrainer(double [] a, Matrix A, double [] b) | |
39 { | |
40 this.A = A; | |
41 n = A.getRowDimension(); | |
42 m = A.getColumnDimension(); | |
43 _A = A.getArray(); | |
44 | |
45 ratep = new VParameter("rate",this); | |
46 T = new Matrix("deltas",n,m); | |
47 _T = T.getArray(); | |
48 this.a=a; this.b=b; | |
49 } | |
50 | |
51 public void set(double r) { rate=r; } | |
52 public double get() { return rate; } | |
53 public VParameter getRate() { return ratep; } | |
54 | |
55 protected void outerProduct(double w, double [] a, double [] b) { | |
56 for (int i=0; i<n; i++) { | |
57 double q=w*a[i], p[] = _T[i]; | |
58 for (int j=0; j<m; j++) p[j] += q*b[j]; | |
59 } | |
60 } | |
61 | |
62 public void dispose() { T.dispose(); ratep.dispose(); } | |
63 public void reset() { T.zero(); T.changed(); count=0; } | |
64 public void oneshot() { accumulate(1); flush(); } | |
65 public void accumulate() { accumulate(1); } | |
66 public void accumulate(double w) { count+=w; outerProduct(w,a,b); } | |
67 | |
68 public void flush() { | |
69 if (count==0) return; | |
70 double eta=rate/count; | |
71 | |
72 T.changed(); // to display accumulated stats | |
73 for (int i=0; i<n; i++) { | |
74 double [] p = _A[i], q = _T[i]; | |
75 for (int j=0; j<m; j++) p[j] += eta*q[j]; | |
76 Mathx.zero(q); | |
77 } | |
78 A.changed(); | |
79 count=0; | |
80 } | |
81 } |