// src/samer/models/SparseICA.java @ 8:5e3cbbf173aa (tip)
// "Reorganise some more" (samer, Fri, 05 Apr 2019 22:41:58 +0100)
/*
 * Copyright (c) 2002, Samer Abdallah, King's College London.
 * All rights reserved.
 *
 * This software is provided AS IS and WITHOUT ANY WARRANTY;
 * without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.
 */

package samer.models;

import samer.core.*;
import samer.core.Agent.*;
import samer.core.types.*;
import samer.tools.*;
import samer.maths.*;
import java.util.*;

public class SparseICA extends NamedTask implements Model, Agent
{
    Model sourceModel;
    int n;
    Vec x;
    SparseMatrix W; // , A;
    VVector s;      // state
    VDouble logA;
    double[] _g;    // cached dE/dx

    public SparseICA(Vec input) { this(input.size()); setInput(input); }

    public SparseICA(int N)
    {
        super("ica");
        Shell.push(node);

        n = N;

        x = null;
        W = new SparseMatrix("W");
        // A = new SparseMatrix("A",n,n);
        s = new VVector("s",n);
        logA = new VDouble("log |A|");
        Shell.pop();

        W.identity(); W.changed();
        _g = new double[n];
    }

    public int getSize() { return n; }
    public VVector output() { return s; }
    public Model getOutputModel() { return sourceModel; }
    public void setOutputModel(Model m) { sourceModel=m; }

    public void setInput(Vec input) {
        if (input.size()!=n) throw new Error("Input vector is the wrong size");
        x = input;
    }
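
    /* Usage sketch: hypothetical wiring, where "prior" stands for any Model
       over the source vector s (not defined in this file), and the setInput
       call on it is assumed from the Model idiom used here:

           Vec input = ...;                       // N-dimensional observation
           SparseICA ica = new SparseICA(input);  // W starts as the identity
           ica.setOutputModel(prior);             // model of the sources s
           prior.setInput(ica.output());          // prior reads s
    */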

    public SparseMatrix getWeightMatrix() { return W; }
    // public Matrix getBasisMatrix() { return A; }

    public void dispose()
    {
        W.dispose(); // A.dispose();
        logA.dispose();
        s.dispose();
        // infer.dispose();
        // grad.dispose();

        super.dispose();
    }

    public void infer() { W.times(s.array(),x.array()); s.changed(); }
    public void compute() { W.transposeTimes(_g,sourceModel.getGradient()); }
    public double getEnergy() { return sourceModel.getEnergy()+logA.value; }
    public double[] getGradient() { return _g; }
    public samer.maths.opt.Functionx functionx() { return null; }

    public void starting() { }
    public void stopping() { }
    public void run() { infer(); }
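
    /* Per-frame sketch of the inference/gradient chain, using the wiring
       above ("prior" is the hypothetical source model):

           ica.infer();                  // s = W*x
           prior.compute();              // source model forms dE/ds
           ica.compute();                // _g = W' * dE/ds, by the chain rule
           double E = ica.getEnergy();   // E_source(s) + log|A|

       Here log|A| = -log|det W| is held in logA but not recomputed on each
       update; the commented-out "logdet" command below suggests it was meant
       to be refreshed occasionally via an SVD of W. */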

    public void getCommands(Registry r) { r.add("basis").add("logdet"); }
    public void execute(String cmd, Environment env)
    {
        /*
        if (cmd.equals("basis")) {
            Shell.print("computing ICA basis...");
            // sparse matrix inverse
            // much easier if matrix is block decomposable!
            A.assign(W.inverse());
            A.changed();
            Shell.print("...done.");
        } else if (cmd.equals("logdet")) {
            Shell.print("computing SVD...");
            // double [] s=W.svd().getSingularValues();
            // Shell.print("...done.");
            // Mathx.log(s);
            // logA.set(-Mathx.sum(s));
        }
        */
    }

    public Trainer getTrainer() { return new ON2Trainer(); }
    public Trainer getAltTrainer() { return new ON3Trainer(); }

    /** This trainer uses an O(N^2) accumulation step and an O(N^2) flush. */
    public class ON2Trainer implements Model.Trainer
    {
        double[] buf;
        VDouble rate;
        SparseMatrix GW;
        int _n;
        double thresh;
        double batch;

        public ON2Trainer()
        {
            _n=n;
            Shell.push(getNode());
            GW=new SparseMatrix("GW",W); // copy pattern from W
            rate=new VDouble("rate",0.01);
            thresh=Shell.getDouble("anomaly",n*20);
            Shell.pop();
            batch=0;

            buf=new double[n]; // general purpose n-buffer
        }

        public void dispose() { GW.dispose(); rate.dispose(); }
        public void reset() { GW.zero(); GW.changed(); batch=0; }
        public void oneshot() { accumulate(1); flush(); }
        public void accumulate() { accumulate(1); }

        public void accumulate(double w) {
            batch+=w;

            W.transposeTimes(buf,s.array());                   // buf = W'*s
            Mathx.mul(buf,w);                                  // buf = w*W'*s
            GW.addOuterProduct(sourceModel.getGradient(),buf); // GW += grad*buf'
            // could subtract w*W here
        }
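
        /* The accumulated statistic is GW = sum_k w_k * g_k * (W'*s_k)',
           where g_k = dE/ds is the source model's gradient, so GW/batch
           estimates <g s'> W. The flush below therefore applies (up to the
           sign convention of the rate argument) the natural-gradient ICA
           rule dW proportional to (I - <g s'>) W, restricted to the sparsity
           pattern of W; the exact semantics of icaUpdate are defined in
           SparseMatrix. */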

        public void flush() {
            if (batch==0) return;
            W.icaUpdate(GW,-rate.value,batch); // W += eta(GW/batch - W)
            W.changed();
            GW.zero();
            batch=0;
        }
    }
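
    /* Typical batch-training loop: a sketch, where nextFrame() is a
       hypothetical stand-in for whatever writes the next observation into
       the input Vec, and B is the batch size:

           Model.Trainer t = ica.getTrainer();   // the O(N^2) trainer
           for (int k=0; k<T; k++) {
               nextFrame();                      // update x in place
               ica.infer();                      // s = W*x
               prior.compute();                  // dE/ds for this frame
               t.accumulate();                   // GW += grad * (W'*s)'
               if ((k+1) % B == 0) t.flush();    // apply update every B frames
           }
    */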

    /** This trainer saves an O(N^2) step during accumulation, at the
        expense of an O(N^3) flush. As long as the batch size is O(N),
        the overall cost should be about the same. The advantage is that
        the collected statistics are more transparent, and can be used
        to make scalar or diagonal updates more frequently. */
    public class ON3Trainer extends AnonymousTask implements Model.Trainer
    {
        SparseMatrix G;
        double[] buf;
        VDouble rate;
        int _n;
        double batch, thresh;

        public ON3Trainer()
        {
            _n=n;
            G=new SparseMatrix("G");
            // set up G so that it has the right links
            rate=new VDouble("rate",0.01);
            thresh=Shell.getDouble("anomaly",20*n);
            batch=0;

            buf=new double[n]; // general purpose n-buffer
        }

        /** This is so you can manipulate the matrix before flushing. */
        public SparseMatrix getGMatrix() { return G; }

        public void starting() { reset(); }
        public void run() { accumulate(); }

        public void dispose() { G.dispose(); rate.dispose(); super.dispose(); }
        public void oneshot() { accumulate(1); flush(); }
        public void reset() { G.zero(); batch=0; }
        public void accumulate() { accumulate(1); }

        public void accumulate(double w) {
            double _s[]=s.array();
            batch+=w;

            Mathx.copy(sourceModel.getGradient(),buf); // buf = dE/ds
            Mathx.mul(buf,w);                          // buf = w*dE/ds
            G.addOuterProduct(buf,_s);                 // G += buf*s'
            // subtract identity to prevent loss of precision?
        }
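
        /* Here G = sum_k w_k * g_k * s_k' is accumulated directly, so each
           frame costs only a sparse outer product; the multiplication by W
           is deferred to the flush, where dW = eta*(W - G*W/batch) would be
           formed per the pseudocode below. Because G estimates <g s'>
           itself, its diagonal can also drive the cheaper scalar or diagonal
           updates mentioned in the class comment. */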

        public void flush()
        {
            if (batch==0) return;

            // NB: the update is left as pseudocode in this revision;
            // only the statistics are reset.
            // compute deltas: dW = W - G*W/batch;
            // then W += eta*dW:
            // double eta=rate.value;
            // for (each link) W[i][j] += eta*dW[i][j];
            reset(); // ready for next batch
        }
    }

    // Initialisation methods

    /** x is an N-dimensional vector.
        R is an N by N symmetric similarity matrix.
        M is the number of unit pairs (butterflies) built into the ICA model. */
    public void init1(Matrix R, int M) {
        double [][] _R=R.getArray();
        TreeSet edges=new TreeSet();

        Shell.status("Building sorted edge list...");
        for (int i=0; i<n; i++) {
            for (int j=i+1; j<n; j++) {
                edges.add(new Edge(i,j,_R[i][j]));
            }
        }
        Shell.status("Edge list complete.");
        Shell.status("Adding edges to sparse matrix...");

        W.allocate(n+2*M);
        for (int i=0; i<n; i++) W.addElement(i,i,1); // identity diagonal

        Iterator it=edges.iterator();
        for (int k=0; k<M; k++) { // take the M strongest edges
            Edge e=(Edge)it.next();
            W.addElement(e.i,e.j,0);
            W.addElement(e.j,e.i,0);
        }
        Shell.status("Sparse matrix complete.");
    }
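
    /* Example: R is hypothetical here and could be any symmetric similarity
       measure, e.g. an input covariance or correlation matrix:

           Matrix R = ...;      // N x N, symmetric
           ica.init1(R, 2*n);   // diagonal plus 2 off-diagonal links per pair

       Edges are iterated in decreasing similarity (see Edge.compareTo below),
       so the M most similar unit pairs receive off-diagonal weights. */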

    /** This version builds only disjoint butterfly pairs, i.e. only 2N
        edges per layer. Need a version that builds bigger local modules. */
    public void init2(Matrix R) {
        double [][] _R=R.getArray();
        boolean [] flags=new boolean[n];
        TreeSet edges=new TreeSet();

        Shell.status("Building sorted edge list...");
        for (int i=0; i<n; i++) {
            for (int j=i+1; j<n; j++) {
                edges.add(new Edge(i,j,_R[i][j]));
            }
        }
        Shell.status("Edge list complete.");
        Shell.status("Adding edges to sparse matrix...");

        W.allocate(2*n);
        Iterator it=edges.iterator();
        for (int k=0; k<n; k+=2) {
            // greedily take the strongest edge joining two unused units
            Edge e;
            do { e=(Edge)it.next(); } while (flags[e.i] || flags[e.j]);
            W.addElement(e.i,e.i,1);
            W.addElement(e.i,e.j,0);
            W.addElement(e.j,e.i,0);
            W.addElement(e.j,e.j,1);
            flags[e.i]=flags[e.j]=true;
        }
        Shell.status("Sparse matrix complete.");
    }
}

class Edge implements Comparable {
    int i, j;
    double x;

    public Edge(int i, int j, double x) { this.i=i; this.j=j; this.x=x; }

    public int compareTo(Object o) {
        Edge e=(Edge)o;
        // NB: REVERSE ordering on x, so iteration runs strongest-first
        if (x>e.x) return -1;
        else if (x<e.x) return 1;
        else if (i<e.i) return -1;
        else if (i>e.i) return 1;
        else if (j<e.j) return -1;
        else if (j>e.j) return 1;
        else return 0;
    }

    public String toString() {
        StringBuffer buf=new StringBuffer("(");
        buf.append(i);
        buf.append(",");
        buf.append(j);
        buf.append(":");
        buf.append(x);
        buf.append(")");
        return buf.toString();
    }
}