samer@0
|
1 /*
|
samer@0
|
2 * Copyright (c) 2002, Samer Abdallah, King's College London.
|
samer@0
|
3 * All rights reserved.
|
samer@0
|
4 *
|
samer@0
|
5 * This software is provided AS IS and WITHOUT ANY WARRANTY;
|
samer@0
|
6 * without even the implied warranty of MERCHANTABILITY or
|
samer@0
|
7 * FITNESS FOR A PARTICULAR PURPOSE.
|
samer@0
|
8 */
|
samer@0
|
9
|
samer@0
|
10 package samer.models;
|
samer@0
|
11 import samer.core.*;
|
samer@0
|
12 import samer.core.Agent.*;
|
samer@0
|
13 import samer.core.types.*;
|
samer@0
|
14 import samer.tools.*;
|
samer@0
|
15 import samer.maths.*;
|
samer@0
|
16 import samer.maths.random.*;
|
samer@0
|
17
|
samer@0
|
18
|
samer@0
|
19
|
samer@0
|
20 public abstract class PCA extends NamedTask implements Model
|
samer@0
|
21 {
|
samer@0
|
22 int n;
|
samer@0
|
23 Vec x;
|
samer@0
|
24 VVector s;
|
samer@0
|
25 Matrix W;
|
samer@0
|
26 VDouble logA;
|
samer@0
|
27 double[] _g,_s; // cached dE/dx
|
samer@0
|
28 Model sourceModel;
|
samer@0
|
29
|
samer@0
|
30 MatrixTimesVector infer;
|
samer@0
|
31 MatrixTransposeTimesVector grad;
|
samer@0
|
32
|
samer@0
|
33 public PCA(Vec input) { this(input.size()); setInput(input); }
|
samer@0
|
34 public PCA(int N)
|
samer@0
|
35 {
|
samer@0
|
36 super("pca");
|
samer@0
|
37 Shell.push(node);
|
samer@0
|
38
|
samer@0
|
39 n = N;
|
samer@0
|
40
|
samer@0
|
41 x = null;
|
samer@0
|
42 W = new Matrix("W",n,n);
|
samer@0
|
43 s = new VVector("s",n);
|
samer@0
|
44 logA = new VDouble( "log |A|",0.0,VDouble.SIGNAL);
|
samer@0
|
45 infer = null;
|
samer@0
|
46 Shell.pop();
|
samer@0
|
47
|
samer@0
|
48 grad = new MatrixTransposeTimesVector(W);
|
samer@0
|
49 _g = new double[n];
|
samer@0
|
50 _s = s.array();
|
samer@0
|
51 }
|
samer@0
|
52
|
samer@0
|
53 public int getSize() { return n; }
|
samer@0
|
54 public VVector output() { return s; }
|
samer@0
|
55 public void setInput(Vec input) {
|
samer@0
|
56 if (input.size()!=n) throw new Error("Input vector is the wrong size");
|
samer@0
|
57 x = input;
|
samer@0
|
58 infer = new MatrixTimesVector(s,W,x);
|
samer@0
|
59 }
|
samer@0
|
60
|
samer@0
|
61 public void dispose()
|
samer@0
|
62 {
|
samer@0
|
63 W.dispose();
|
samer@0
|
64 logA.dispose();
|
samer@0
|
65 s.dispose();
|
samer@0
|
66 infer.dispose();
|
samer@0
|
67 grad.dispose();
|
samer@0
|
68
|
samer@0
|
69 super.dispose();
|
samer@0
|
70 }
|
samer@0
|
71
|
samer@0
|
72 public void infer() { infer.run(); s.changed(); } // compute s=Wx
|
samer@0
|
73 public void compute() {
|
samer@0
|
74 grad.apply(_g,_s);
|
samer@0
|
75 }
|
samer@0
|
76
|
samer@0
|
77 public double getEnergy() { return logA.value + sourceModel.getEnergy(); }
|
samer@0
|
78 public double[] getGradient() { return _g; }
|
samer@0
|
79
|
samer@0
|
80 public void starting() {}
|
samer@0
|
81 public void stopping() {}
|
samer@0
|
82 public void run() { infer(); }
|
samer@0
|
83
|
samer@0
|
84 public Task getTrainingTask() { return new Trainer(); }
|
samer@0
|
85
|
samer@0
|
86 public class Trainer extends NamedTask
|
samer@0
|
87 {
|
samer@0
|
88 Matrix G;
|
samer@0
|
89 double[][] _G, _W;
|
samer@0
|
90 double[] _g,_s, buf;
|
samer@0
|
91 VDouble rate;
|
samer@0
|
92 int batch, _n;
|
samer@0
|
93
|
samer@0
|
94 public Trainer()
|
samer@0
|
95 {
|
samer@0
|
96 super("learn",PCA.this.getNode());
|
samer@0
|
97 Shell.push(Trainer.this.node);
|
samer@0
|
98 _n=n;
|
samer@0
|
99 G=new Matrix("G",n,n);
|
samer@0
|
100 rate=new VDouble("rate",0.01);
|
samer@0
|
101 Shell.pop();
|
samer@0
|
102 batch=0;
|
samer@0
|
103
|
samer@0
|
104 _s=s.array();
|
samer@0
|
105 _G=G.getArray();
|
samer@0
|
106 _W=W.getArray();
|
samer@0
|
107 buf=new double[n]; // general purpose n-buffer
|
samer@0
|
108 }
|
samer@0
|
109
|
samer@0
|
110 public void dispose() { G.dispose(); rate.dispose(); super.dispose(); }
|
samer@0
|
111 public void starting() { G.zero(); G.changed(); batch=0; }
|
samer@0
|
112 public void run()
|
samer@0
|
113 {
|
samer@0
|
114 double q, p[];
|
samer@0
|
115 batch++;
|
samer@0
|
116
|
samer@0
|
117 double[] phi=sourceModel.getGradient();
|
samer@0
|
118 for (int i=0; i<_n; i++) {
|
samer@0
|
119 p = _G[i]; q=phi[i];
|
samer@0
|
120 for (int j=0; j<_n; j++) p[j] += q*_s[j];
|
samer@0
|
121 p[i] -= 1;
|
samer@0
|
122 }
|
samer@0
|
123 }
|
samer@0
|
124
|
samer@0
|
125 public final void updateG() { G.changed(); }
|
samer@0
|
126 public void flush()
|
samer@0
|
127 {
|
samer@0
|
128 double eta=-rate.value/batch;
|
samer@0
|
129
|
samer@0
|
130 // this is going to do a matrix G *= W, in place
|
samer@0
|
131 for (int i=0; i<_n; i++) {
|
samer@0
|
132 for (int j=0; j<_n; j++) {
|
samer@0
|
133 double a=0;
|
samer@0
|
134 for (int k=0; k<_n; k++) a += _G[i][k]*_W[k][j];
|
samer@0
|
135 buf[j] = a;
|
samer@0
|
136 }
|
samer@0
|
137 Mathx.copy(buf,_G[i]);
|
samer@0
|
138 }
|
samer@0
|
139
|
samer@0
|
140 // now W += eta*G
|
samer@0
|
141 for (int i=0; i<_n; i++) {
|
samer@0
|
142 double [] p = _W[i], q = _G[i];
|
samer@0
|
143 for (int j=0; j<n; j++) p[j] += eta*q[j];
|
samer@0
|
144 }
|
samer@0
|
145
|
samer@0
|
146 // reset for next batch
|
samer@0
|
147 W.changed();
|
samer@0
|
148 G.zero();
|
samer@0
|
149 batch=0;
|
samer@0
|
150 }
|
samer@0
|
151 }
|
samer@0
|
152 }
|
samer@0
|
153
|