/*
 * Copyright (c) 2002, Samer Abdallah, King's College London.
 * All rights reserved.
 *
 * This software is provided AS IS and WITHOUT ANY WARRANTY;
 * without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.
 */

package samer.models;
import samer.core.*;
import samer.core.Agent.*;
import samer.core.types.*;
import samer.tools.*;
import samer.maths.*;
import java.util.*;

public class SparseICA extends NamedTask implements Model, Agent
{
	Model sourceModel;  // prior model over the sources s
	int n;              // dimension of input and output
	Vec x;              // input vector
	SparseMatrix W;     // sparse unmixing matrix (basis matrix A is disabled)
	VVector s;          // state: inferred sources, s = W*x
	VDouble logA;       // log |det A|, where A = inv(W)
	double[] _g;        // cached dE/dx

	public SparseICA(Vec input) { this(input.size()); setInput(input); }
	public SparseICA(int N)
	{
		super("ica");
		Shell.push(node);

		n = N;

		x = null;
		W = new SparseMatrix("W");
		// A = new SparseMatrix("A",n,n);
		s = new VVector("s",n);
		logA = new VDouble("log |A|");
		Shell.pop();

		W.identity(); W.changed();
		_g = new double[n];
	}

	public int getSize() { return n; }
	public VVector output() { return s; }
	public Model getOutputModel() { return sourceModel; }
	public void setOutputModel(Model m) { sourceModel=m; }
	public void setInput(Vec input) {
		if (input.size()!=n) throw new Error("Input vector is the wrong size");
		x = input;
	}

	public SparseMatrix getWeightMatrix() { return W; }
	// public Matrix getBasisMatrix() { return A; }

	public void dispose()
	{
		W.dispose(); // A.dispose();
		logA.dispose();
		s.dispose();
		// infer.dispose();
		// grad.dispose();

		super.dispose();
	}

	public void infer() { W.times(s.array(),x.array()); s.changed(); }        // s = W*x
	public void compute() { W.transposeTimes(_g,sourceModel.getGradient()); } // _g = W'*(dE/ds)
	public double getEnergy() { return sourceModel.getEnergy()+logA.value; }
	public double[] getGradient() { return _g; }
	public samer.maths.opt.Functionx functionx() { return null; }

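	/* The model methods above implement the standard square-ICA energy:
	   with s = W*x and A = inv(W), the negative log-likelihood is

	       E(x) = E_source(s) + log |det A|

	   and by the chain rule the input gradient cached in _g is

	       dE/dx = W' * (dE/ds).

	   infer() computes s, compute() pulls the source gradient back through
	   W, and getEnergy() adds the (externally maintained) log |det A| term
	   held in logA. */
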
	public void starting() { }
	public void stopping() { }
	public void run() { infer(); }

	public void getCommands(Registry r) { r.add("basis").add("logdet"); }
	public void execute(String cmd, Environment env)
	{
		/*
		if (cmd.equals("basis")) {
			Shell.print("computing ICA basis...");
			// sparse matrix inverse
			// much easier if matrix is block decomposable!
			A.assign(W.inverse());
			A.changed();
			Shell.print("...done.");
		} else if (cmd.equals("logdet")) {
			Shell.print("computing SVD...");
			// double [] s=W.svd().getSingularValues();
			// Shell.print("...done.");
			// Mathx.log(s);
			// logA.set(-Mathx.sum(s));
		}
		*/
	}

	public Trainer getTrainer() { return new ON2Trainer(); }
	public Trainer getAltTrainer() { return new ON3Trainer(); }
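
	/* A minimal usage sketch for the trainers. This is hypothetical driver
	   code, not part of the class: `prior' and `batchSize' are assumed to
	   be supplied by the caller.

	       SparseICA ica = new SparseICA(x);    // x: input Vec of size N
	       ica.setOutputModel(prior);           // prior: Model over the sources
	       Model.Trainer t = ica.getTrainer();  // the O(N^2) ON2Trainer
	       t.reset();
	       for (int k=0; k<batchSize; k++) {
	           // ... write the next sample into x ...
	           ica.infer();                     // s = W*x
	           prior.infer(); prior.compute();  // refresh dE/ds (assumed protocol)
	           t.accumulate();
	       }
	       t.flush();                           // apply the batched update to W
	*/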

	/** This trainer uses an O(N^2) run step and an O(N^2) flush. */

	public class ON2Trainer implements Model.Trainer
	{
		double[] buf;
		VDouble rate;
		SparseMatrix GW;
		int _n;
		double thresh;
		double batch;

		public ON2Trainer()
		{
			_n=n;
			Shell.push(getNode());
			GW=new SparseMatrix("GW",W); // copy pattern from W
			rate=new VDouble("rate",0.01);
			thresh=Shell.getDouble("anomaly",n*20);
			Shell.pop();
			batch=0;

			buf=new double[n]; // general purpose n-buffer
		}

		public void dispose() { GW.dispose(); rate.dispose(); }
		public void reset() { GW.zero(); GW.changed(); batch=0; }
		public void oneshot() { accumulate(1); flush(); }
		public void accumulate() { accumulate(1); }
		public void accumulate(double w) {
			batch+=w;

			W.transposeTimes(buf,s.array());                   // buf = W'*s
			Mathx.mul(buf,w);                                  // buf = w*W'*s
			GW.addOuterProduct(sourceModel.getGradient(),buf); // GW += grad*buf'
			// could subtract w*W here
		}

		public void flush() {
			if (batch==0) return;
			W.icaUpdate(GW,-rate.value,batch); // W += eta*(GW/batch - W)
			W.changed();
			GW.zero();
			batch=0;
		}
	}

	/** This trainer saves an O(N^2) step during accumulation, at the
	    expense of an O(N^3) flush. As long as the batch size is O(N),
	    the overall cost is about the same. The advantage is that the
	    collected statistics are more transparent, and can be used to
	    make scalar or diagonal updates more frequently. */

	public class ON3Trainer extends AnonymousTask implements Model.Trainer
	{
		SparseMatrix G;
		double[] buf;
		VDouble rate;
		int _n;
		double batch, thresh;

		public ON3Trainer()
		{
			_n=n;
			G=new SparseMatrix("G");
			// TODO: set up G so that it has the right links
			// (cf. the pattern-copying constructor used for GW in ON2Trainer)
			rate=new VDouble("rate",0.01);
			thresh=Shell.getDouble("anomaly",20*n);
			batch=0;

			buf=new double[n]; // general purpose n-buffer
		}

		/** This is so you can manipulate the matrix before flushing. */
		public SparseMatrix getGMatrix() { return G; }

		public void starting() { reset(); }
		public void run() { accumulate(); }

		public void dispose() { G.dispose(); rate.dispose(); super.dispose(); }
		public void oneshot() { accumulate(1); flush(); }
		public void reset() { G.zero(); batch=0; }
		public void accumulate() { accumulate(1); }
		public void accumulate(double w) {
			double _s[]=s.array();
			batch+=w;

			Mathx.copy(sourceModel.getGradient(),buf); // buf = dE/ds
			Mathx.mul(buf,w);                          // buf = w*dE/ds
			G.addOuterProduct(buf,_s);                 // G += buf*s'
			// subtract identity to prevent loss of precision?
		}

		public void flush()
		{
			if (batch==0) return;

			// Not yet implemented: compute deltas dW = W - G*W/batch,
			// then apply W += eta*dW over the links of W, with eta=rate.value.
			reset(); // ready for next batch
		}
	}
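
	/* For reference, the update ON3Trainer.flush() is meant to apply, in
	   dense notation (a sketch of intent only; the real implementation
	   must run over the sparse links of W and G):

	       dW = W - G*W/batch;                  // update direction
	       for (each link i,j of W)
	           W[i][j] += rate.value * dW[i][j];

	   Keeping the raw statistic G = sum over the batch of w*grad*s' is what
	   makes the cheaper scalar or diagonal updates mentioned above possible. */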

	// Initialisation methods

	/** x is an N-dim vector.
	    R is an N by N symmetric similarity matrix.
	    M is the number of unit pairs (butterflies) built into the ICA model. */

	public void init1(Matrix R, int M) {
		double [][] _R=R.getArray();
		TreeSet edges=new TreeSet();

		Shell.status("Building sorted edge list...");
		for (int i=0; i<n; i++) {
			for (int j=i+1; j<n; j++) {
				edges.add(new Edge(i,j,_R[i][j]));
			}
		}
		Shell.status("Edge list complete.");
		Shell.status("Adding edges to sparse matrix...");

		W.allocate(n+2*M);
		for (int i=0; i<n; i++) W.addElement(i,i,1); // identity diagonal

		Iterator it=edges.iterator();
		for (int k=0; k<M; k++) {                    // M strongest edges
			Edge e=(Edge)it.next();
			W.addElement(e.i,e.j,0);
			W.addElement(e.j,e.i,0);
		}
		Shell.status("Sparse matrix complete.");
	}
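
	/* A hypothetical call sequence for init1; R and M are illustrative,
	   and any symmetric similarity measure over the n inputs will do:

	       Matrix R = ...;     // e.g. |correlation| between inputs i and j
	       ica.init1(R, 3*n);  // identity diagonal plus 3n butterfly pairs

	   With M butterflies the matrix holds n + 2*M nonzero elements,
	   matching the W.allocate(n+2*M) call above. */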

	/** This version only builds disjoint butterfly pairs, i.e. only 2n
	    matrix elements per layer (cf. W.allocate(2*n)). A version that
	    builds bigger local modules is still needed. */
	public void init2(Matrix R) {
		double [][] _R=R.getArray();
		boolean [] flags=new boolean[n];
		TreeSet edges=new TreeSet();

		Shell.status("Building sorted edge list...");
		for (int i=0; i<n; i++) {
			for (int j=i+1; j<n; j++) {
				edges.add(new Edge(i,j,_R[i][j]));
			}
		}
		Shell.status("Edge list complete.");
		Shell.status("Adding edges to sparse matrix...");

		W.allocate(2*n);
		Iterator it=edges.iterator();
		for (int k=0; k<n; k+=2) {
			Edge e;
			// take the strongest edge whose endpoints are both still free
			do { e=(Edge)it.next(); } while (flags[e.i] || flags[e.j]);
			W.addElement(e.i,e.i,1);
			W.addElement(e.i,e.j,0);
			W.addElement(e.j,e.i,0);
			W.addElement(e.j,e.j,1);
			flags[e.i]=flags[e.j]=true;
		}
		Shell.status("Sparse matrix complete.");
	}
}

class Edge implements Comparable {
	int i, j;
	double x;

	public Edge(int i, int j, double x) { this.i=i; this.j=j; this.x=x; }
	public int compareTo(Object o) {
		Edge e=(Edge)o;
		// NB: REVERSE ordering on x, so the strongest edge sorts first
		if (x>e.x) return -1;
		else if (x<e.x) return 1;
		else if (i<e.i) return -1;
		else if (i>e.i) return 1;
		else if (j<e.j) return -1;
		else if (j>e.j) return 1;
		else return 0;
	}
	public String toString() {
		StringBuffer buf=new StringBuffer("(");
		buf.append(i);
		buf.append(",");
		buf.append(j);
		buf.append(":");
		buf.append(x);
		buf.append(")");
		return buf.toString();
	}
}