annotate src/samer/models/NoisyICA.java @ 8:5e3cbbf173aa tip

Reorganise some more
author samer
date Fri, 05 Apr 2019 22:41:58 +0100
parents bf79fb79ee13
children
rev   line source
samer@0 1 /*
samer@0 2 * Copyright (c) 2000, Samer Abdallah, King's College London.
samer@0 3 * All rights reserved.
samer@0 4 *
samer@0 5 * This software is provided AS IS and WITHOUT ANY WARRANTY;
samer@0 6 * without even the implied warranty of MERCHANTABILITY or
samer@0 7 * FITNESS FOR A PARTICULAR PURPOSE.
samer@0 8 */
samer@0 9
samer@0 10 package samer.models;
samer@0 11 import samer.core.*;
samer@0 12 import samer.core.types.*;
samer@0 13 import samer.tools.*;
samer@0 14 import samer.maths.*;
samer@0 15 import samer.maths.*;
samer@0 16 import samer.maths.opt.*;
samer@0 17
samer@0 18 public class NoisyICA extends NamedTask implements Model
samer@0 19 {
samer@0 20 Vec x; // data (input)
samer@0 21 int n, m; // sizes (data, sources)
samer@0 22 Matrix A; // basis matrix
samer@0 23 Model Ms, Me; // source and noise models
samer@0 24 VVector s, z, e; // sources, reconstruction, error
samer@0 25 VDouble E; // energy
samer@0 26
samer@0 27 Task inference=new NullTask();
samer@0 28
samer@0 29
samer@0 30 // ----- variables used in computations -----------------
samer@0 31
samer@0 32 private double [] e_;
samer@0 33 private double [] x_;
samer@0 34 private double [] z_;
samer@0 35 private double [] s_;
samer@0 36 private VectorFunctionOfVector tA, tAt;
samer@0 37
samer@0 38
samer@0 39 public NoisyICA(Vec in,int outs) { this(new Node("noisyica"),in.size(),outs); setInput(in); }
samer@0 40 public NoisyICA(int ins,int outs) { this(new Node("noisyica"),ins,outs); }
samer@0 41 public NoisyICA(Node node, int inputs, int outputs)
samer@0 42 {
samer@0 43 super(node);
samer@0 44 Shell.push(node);
samer@0 45
samer@0 46 n = inputs;
samer@0 47 m = outputs;
samer@0 48
samer@0 49 s = new VVector("s",m);
samer@0 50 z = new VVector("z",n);
samer@0 51 e = new VVector("e",n);
samer@0 52 E = new VDouble("E");
samer@0 53 A = new Matrix("A",n,m);
samer@0 54 A.identity();
samer@0 55
samer@0 56 e_ = e.array();
samer@0 57 z_ = z.array();
samer@0 58 s_ = s.array();
samer@0 59 tA = new MatrixTimesVector(A);
samer@0 60 tAt= new MatrixTransposeTimesVector(A);
samer@0 61
samer@0 62 Shell.pop();
samer@0 63 }
samer@0 64
samer@0 65 public Model getSourceModel() { return Ms; }
samer@0 66 public Model getNoiseModel() { return Me; }
samer@0 67 public void setSourceModel(Model m) { Ms=m; Ms.setInput(s); }
samer@0 68 public void setNoiseModel(Model m) { Me=m; Me.setInput(e); }
samer@0 69 public void setInput(Vec in) { x=in; x_ = x.array(); }
samer@0 70 public Matrix basisMatrix() { return A; }
samer@0 71 public VVector output() { return s; }
samer@0 72 public VVector error() { return e; }
samer@0 73 public VVector reconstruction() { return z; }
samer@0 74
samer@0 75 public int getSize() { return n; }
samer@0 76
samer@0 77 public void setInferenceTask(Task t) { inference=t; }
samer@0 78
samer@0 79 public void infer() {
samer@0 80 try { inference.run(); }
samer@0 81 catch (Exception ex) {
samer@0 82 Shell.trace("error: "+ex);
samer@0 83 ex.printStackTrace();
samer@0 84 throw new Error("inference failed: "+ex); }
samer@0 85 tA.apply(s_,z_); Mathx.sub(e_,x_,z_);
samer@0 86 e.changed(); z.changed(); s.changed();
samer@0 87 }
samer@0 88
samer@0 89 public void compute() {
samer@0 90 E.set(Me.getEnergy() + Ms.getEnergy());
samer@0 91 // what about dE/dx?
samer@0 92 }
samer@0 93 public double getEnergy() { return E.value; }
samer@0 94 public double [] getGradient() { return null; } // this is wrong
samer@0 95
samer@0 96 /** get basis vector norms into given array */
samer@0 97 public void norms(double [] na)
samer@0 98 {
samer@0 99 double [][] M=A.getArray();
samer@0 100
samer@0 101 Mathx.zero(na);
samer@0 102 for (int i=0; i<n; i++) {
samer@0 103 double [] Mi=M[i];
samer@0 104 for (int j=0; j<m; j++) {
samer@0 105 na[j] += Mi[j]*Mi[j];
samer@0 106 }
samer@0 107 }
samer@0 108 }
samer@0 109
samer@0 110 public void run() { infer(); }
samer@0 111 public void dispose()
samer@0 112 {
samer@0 113 s.dispose();
samer@0 114 A.dispose();
samer@0 115 e.dispose();
samer@0 116 z.dispose();
samer@0 117 E.dispose();
samer@0 118 tA.dispose();
samer@0 119 tAt.dispose();
samer@0 120
samer@0 121 super.dispose();
samer@0 122 }
samer@0 123
samer@0 124 public Functionx functionx() { return null; }
samer@0 125
samer@0 126 public Functionx posterior() {
samer@0 127 // returns Functionx which evaluates E and dE/ds at current x
samer@0 128 return new Functionx() {
samer@0 129 Functionx fMs=Ms.functionx();
samer@0 130 Functionx fMe=Me.functionx();
samer@0 131 double [] e=new double[n];
samer@0 132 double [] ge=new double[n];
samer@0 133 double [] gs=new double[m];
samer@0 134
samer@0 135 public void dispose() { fMs.dispose(); fMe.dispose(); }
samer@0 136 public void evaluate( Datum P) { P.f=evaluate(P.x,P.g); }
samer@0 137 public double evaluate( double [] s, double [] g)
samer@0 138 {
samer@0 139 tA.apply(s,z_); Mathx.sub(e,x_,z_);
samer@0 140 double Ee=fMe.evaluate(e,ge); tAt.apply(ge,gs); // gs=A'*gamma(e)
samer@0 141 double Es=fMs.evaluate(s,g); Mathx.sub(g,gs); // g=gamma(s)-ge
samer@0 142 return Es+Ee;
samer@0 143 }
samer@0 144 };
samer@0 145 }
samer@0 146
samer@0 147 public Trainer learnHebbian() { return new MatrixTrainer(e_,A,s_); }
samer@0 148
samer@0 149 public Trainer learnLewickiSejnowski()
samer@0 150 {
samer@0 151 final double [] h =new double[n];
samer@0 152 final double [] f =new double[m];
samer@0 153 return new MatrixTrainer(h,A,s_) {
samer@0 154
samer@0 155 public void accumulate(double w) {
samer@0 156 // tAt.apply(Me.getGradient(),f); tA.apply(f,h);
samer@0 157 tA.apply(Ms.getGradient(),h);
samer@0 158 super.accumulate(w);
samer@0 159 }
samer@0 160 public void flush() { // flush with decay
samer@0 161 for (int j=0; j<this.m; j++)
samer@0 162 for (int i=0; i<this.n; i++)
samer@0 163 _T[i][j] -= count*_A[i][j];
samer@0 164 super.flush();
samer@0 165 }
samer@0 166 };
samer@0 167 }
samer@0 168
samer@0 169 public Trainer learnDecayWhenActive()
samer@0 170 {
samer@0 171 final double [] h =new double[n];
samer@0 172 return new MatrixTrainer(h,A,s_) {
samer@0 173 VDouble th=new VDouble("threshold",0.01);
samer@0 174 public void accumulate(double w) {
samer@0 175 // perhaps would like to get info out of optimiser here
samer@0 176 tA.apply(Ms.getGradient(),h);
samer@0 177 super.accumulate(w);
samer@0 178
samer@0 179 // decay when active part
samer@0 180 double thresh=th.value;
samer@0 181 for (int j=0; j<this.m; j++)
samer@0 182 if (isActive(s_[j],thresh))
samer@0 183 for (int i=0; i<this.n; i++)
samer@0 184 _T[i][j] -= _A[i][j];
samer@0 185 }
samer@0 186
samer@0 187 public VDouble getThreshold() { return th; }
samer@0 188 };
samer@0 189 }
samer@0 190 static boolean isActive(double s, double t) { return s>t || s<-t; }
samer@0 191 }