src/samer/models/SparseICA.java @ 5:b67a33c44de7

changeset description: Remove some crap, etc

author    samer
date      Fri, 05 Apr 2019 21:34:25 +0100
parents   bf79fb79ee13
children  (none)
/*
 * Copyright (c) 2002, Samer Abdallah, King's College London.
 * All rights reserved.
 *
 * This software is provided AS IS and WITHOUT ANY WARRANTY;
 * without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.
 */
package samer.models;

import samer.core.*;
import samer.core.Agent.*;
import samer.core.types.*;
import samer.tools.*;
import samer.maths.*;
import java.util.*;

public class SparseICA extends NamedTask implements Model, Agent
{
    Model sourceModel;
    int n;
    Vec x;
    SparseMatrix W; // , A;
    VVector s;      // state
    VDouble logA;
    double[] _g;    // cached dE/dx

    public SparseICA(Vec input) { this(input.size()); setInput(input); }

    public SparseICA(int N)
    {
        super("ica");
        Shell.push(node);
        n = N;
        x = null;
        W = new SparseMatrix("W");
        // A = new SparseMatrix("A",n,n);
        s = new VVector("s",n);
        logA = new VDouble("log |A|");
        Shell.pop();

        W.identity();
        W.changed();
        _g = new double[n];
    }

    public int getSize() { return n; }
    public VVector output() { return s; }
    public Model getOutputModel() { return sourceModel; }
    public void setOutputModel(Model m) { sourceModel=m; }

    public void setInput(Vec input)
    {
        if (input.size()!=n) throw new Error("Input vector is the wrong size");
        x = input;
    }

    public SparseMatrix getWeightMatrix() { return W; }
    // public Matrix getBasisMatrix() { return A; }

    public void dispose()
    {
        W.dispose(); // A.dispose();
        logA.dispose();
        s.dispose();
        // infer.dispose();
        // grad.dispose();
        super.dispose();
    }

    public void infer()   { W.times(s.array(),x.array()); s.changed(); }
    public void compute() { W.transposeTimes(_g,sourceModel.getGradient()); }
    public double getEnergy() { return sourceModel.getEnergy()+logA.value; }
    public double[] getGradient() { return _g; }
    public samer.maths.opt.Functionx functionx() { return null; }

    public void starting() {}
    public void stopping() {}
    public void run() { infer(); }

    public void getCommands(Registry r) { r.add("basis").add("logdet"); }

    public void execute(String cmd, Environment env)
    {
        /*
        if (cmd.equals("basis")) {
            Shell.print("computing ICA basis...");
            // sparse matrix inverse:
            // much easier if matrix is block decomposable!
            A.assign(W.inverse()); A.changed();
            Shell.print("...done.");
        } else if (cmd.equals("logdet")) {
            Shell.print("computing SVD...");
            // double [] s=W.svd().getSingularValues();
            // Shell.print("...done.");
            // Mathx.log(s);
            // logA.set(-Mathx.sum(s));
        }
        */
    }

    public Trainer getTrainer()    { return new ON2Trainer(); }
    public Trainer getAltTrainer() { return new ON3Trainer(); }
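    /*
     * A sketch of the maths implied by the methods above, assuming
     * SparseMatrix.times(y,x) computes y = W*x: infer() computes the
     * source activations s = W*x; compute() applies the chain rule to
     * cache dE/dx = W'*(dE/ds), where dE/ds is the source model's
     * gradient; and getEnergy() returns E(x) = E_source(s) + log|A|,
     * where A = inv(W) is the (commented-out) basis matrix, so
     * log|A| = -log|det W|. Both trainers below accumulate the batch
     * statistic <g*s'>*W with g = dE/ds, and flush with a
     * natural-gradient style step W += eta*(I - <g*s'>)*W.
     */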
    /** This trainer uses an O(N^2) run step and an O(N^2) flush. */
    public class ON2Trainer implements Model.Trainer
    {
        double[] buf;
        VDouble rate;
        SparseMatrix GW;
        int _n;
        double thresh;
        double batch;

        public ON2Trainer()
        {
            _n=n;
            Shell.push(getNode());
            GW=new SparseMatrix("GW",W); // copy pattern from W
            rate=new VDouble("rate",0.01);
            thresh=Shell.getDouble("anomaly",n*20);
            Shell.pop();
            batch=0;
            buf=new double[n]; // general purpose n-buffer
        }

        public void dispose() { GW.dispose(); rate.dispose(); }
        public void reset()   { GW.zero(); GW.changed(); batch=0; }
        public void oneshot() { accumulate(1); flush(); }
        public void accumulate() { accumulate(1); }

        public void accumulate(double w)
        {
            batch+=w;
            W.transposeTimes(buf,s.array()); // buf = W'*s
            Mathx.mul(buf,w);                // buf = w*W'*s
            GW.addOuterProduct(sourceModel.getGradient(),buf); // GW += grad*buf'
            // could subtract w*W here
        }

        public void flush()
        {
            if (batch==0) return;
            W.icaUpdate(GW,-rate.value,batch); // W += eta*(GW/batch - W)
            W.changed();
            GW.zero();
            batch=0;
        }
    }

    /** This trainer saves an O(N^2) step during accumulation, at the
        expense of an O(N^3) flush. As long as the batch size is O(N),
        it should be about the same cost overall. The advantage is that
        the collected statistics are more transparent, and can be used
        to make scalar or diagonal updates more frequently. */
    public class ON3Trainer extends AnonymousTask implements Model.Trainer
    {
        SparseMatrix G;
        double[] buf;
        VDouble rate;
        int _n;
        double batch, thresh;

        public ON3Trainer()
        {
            _n=n;
            G=new SparseMatrix("G"); // set up G so that it has the right links
            rate=new VDouble("rate",0.01);
            thresh=Shell.getDouble("anomaly",20*n);
            batch=0;
            buf=new double[n]; // general purpose n-buffer
        }

        /** This is so you can manipulate the matrix before flushing. */
        public SparseMatrix getGMatrix() { return G; }

        public void starting() { reset(); }
        public void run() { accumulate(); }
        public void dispose() { G.dispose(); rate.dispose(); super.dispose(); }
        public void oneshot() { accumulate(1); flush(); }
        public void reset()   { G.zero(); batch=0; }
        public void accumulate() { accumulate(1); }

        public void accumulate(double w)
        {
            double _s[]=s.array();
            batch+=w;
            Mathx.copy(sourceModel.getGradient(),buf);
            Mathx.mul(buf,w);
            G.addOuterProduct(buf,_s); // G += w*grad*s'
            // subtract identity to prevent loss of precision?
        }

        public void flush()
        {
            if (batch==0) return;
            // compute deltas: dW = W - G*W/batch;
            // now W += eta*dW
            // double eta=rate.value;
            // for (each link) W[i][j] += eta*dW[i][j];
            reset(); // ready for next batch
        }
    }

    // Initialisation methods

    /** The model input x is an N-dim vector. R is an N by N symmetric
        similarity matrix. M is the number of unit pairs (butterflies)
        built into the ICA model. */
    public void init1(Matrix R, int M)
    {
        double [][] _R=R.getArray();
        TreeSet edges=new TreeSet();

        Shell.status("Building sorted edge list...");
        for (int i=0; i<n; i++)
            for (int j=i+1; j<n; j++)
                edges.add(new Edge(i,j,_R[i][j]));
        Shell.status("Edge list complete.");

        Shell.status("Adding edges to sparse matrix...");
        W.allocate(n+2*M);
        for (int i=0; i<n; i++) W.addElement(i,i,1);
        Iterator it=edges.iterator();
        for (int k=0; k<M; k++) {
            Edge e=(Edge)it.next();
            W.addElement(e.i,e.j,0);
            W.addElement(e.j,e.i,0);
        }
        Shell.status("Sparse matrix complete.");
    }
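    /*
     * For illustration (hypothetical numbers): with n=4, M=1, and R
     * largest between units 0 and 2, init1 allocates n+2*M = 6 stored
     * elements and produces the sparsity pattern
     *
     *     [ 1 . 0 . ]
     *     [ . 1 . . ]
     *     [ 0 . 1 . ]     ( "." = no stored element )
     *     [ . . . 1 ]
     *
     * i.e. the identity plus one zero-initialised "butterfly" coupling
     * units 0 and 2, which training can then grow without changing the
     * sparsity pattern.
     */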
    /** This version only builds disjoint butterfly pairs, i.e. only 2N
        edges per layer. Need a version that builds bigger local modules. */
    public void init2(Matrix R)
    {
        double [][] _R=R.getArray();
        boolean [] flags=new boolean[n];
        TreeSet edges=new TreeSet();

        Shell.status("Building sorted edge list...");
        for (int i=0; i<n; i++)
            for (int j=i+1; j<n; j++)
                edges.add(new Edge(i,j,_R[i][j]));
        Shell.status("Edge list complete.");

        Shell.status("Adding edges to sparse matrix...");
        W.allocate(2*n);
        Iterator it=edges.iterator();
        for (int k=0; k<n; k+=2) {
            Edge e;
            do { e=(Edge)it.next(); } while (flags[e.i] || flags[e.j]);
            W.addElement(e.i,e.i,1);
            W.addElement(e.i,e.j,0);
            W.addElement(e.j,e.i,0);
            W.addElement(e.j,e.j,1);
            flags[e.i]=flags[e.j]=true;
        }
        Shell.status("Sparse matrix complete.");
    }
}

class Edge implements Comparable
{
    int i, j;
    double x;

    public Edge(int i, int j, double x) { this.i=i; this.j=j; this.x=x; }

    public int compareTo(Object o)
    {
        Edge e=(Edge)o;
        // NB: REVERSE ordering on x, so iteration runs from the
        // largest similarity down
        if (x>e.x) return -1;
        else if (x<e.x) return 1;
        else if (i<e.i) return -1;
        else if (i>e.i) return 1;
        else if (j<e.j) return -1;
        else if (j>e.j) return 1;
        else return 0;
    }

    public String toString()
    {
        StringBuffer buf=new StringBuffer("(");
        buf.append(i); buf.append(",");
        buf.append(j); buf.append(":");
        buf.append(x); buf.append(")");
        return buf.toString();
    }
}
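For context, a minimal usage sketch of the class above, assuming only the calls visible in this file. The names input, priorModel, R, T and B are hypothetical stand-ins for an n-dimensional Vec, a source Model built on ica.output(), a similarity Matrix, and the step and batch counts, all of which would come from elsewhere in jslab; whether the source model is refreshed via its own infer()/compute() is also an assumption.

    // Hypothetical wiring; input, priorModel, R, T and B defined elsewhere.
    SparseICA ica = new SparseICA(input);     // W starts as the identity
    ica.setOutputModel(priorModel);           // supplies E_source(s) and dE/ds
    ica.init2(R);                             // optional: disjoint butterfly pattern
    Model.Trainer trainer = ica.getTrainer(); // the O(N^2) trainer

    for (int t = 0; t < T; t++) {
        // ... load the next observation into the input vector ...
        ica.infer();                          // s = W*x
        priorModel.infer();                   // refresh the source model's
        priorModel.compute();                 // gradient dE/ds at the new s
        trainer.accumulate();                 // GW += grad * (W'*s)'
        if ((t+1) % B == 0) trainer.flush();  // apply the batched weight update
    }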