annotate src/samer/models/VarianceICA.java @ 3:15b93db27c04

Get StreamSource to compile, update args for demo
author samer
date Fri, 05 Apr 2019 17:00:18 +0100
parents bf79fb79ee13
children
rev   line source
samer@0 1 /*
samer@0 2 * Copyright (c) 2000, Samer Abdallah, King's College London.
samer@0 3 * All rights reserved.
samer@0 4 *
samer@0 5 * This software is provided AS IS and WITHOUT ANY WARRANTY;
samer@0 6 * without even the implied warranty of MERCHANTABILITY or
samer@0 7 * FITNESS FOR A PARTICULAR PURPOSE.
samer@0 8 */
samer@0 9
samer@0 10 package samer.models;
samer@0 11 import samer.core.*;
samer@0 12 import samer.core.types.*;
samer@0 13 import samer.tools.*;
samer@0 14 import samer.maths.*;
samer@0 15 import samer.maths.*;
samer@0 16 import samer.maths.opt.*;
samer@0 17
samer@0 18 public class VarianceICA extends NamedTask implements Model
samer@0 19 {
samer@0 20 Vec x; // data (input)
samer@0 21 int n, m; // sizes (data, sources)
samer@0 22 Matrix A; // basis matrix
samer@0 23 Model Ms; // source model
samer@0 24 VVector s, z, e; // sources, reconstruction, error
samer@0 25 VDouble E; // energy
samer@0 26
samer@0 27 Task inference=new NullTask();
samer@0 28
samer@0 29
samer@0 30 // ----- variables used in computations -----------------
samer@0 31
samer@0 32 private double [] e_;
samer@0 33 private double [] x_;
samer@0 34 private double [] z_;
samer@0 35 private double [] s_;
samer@0 36 private VectorFunctionOfVector tA, tAt;
samer@0 37
samer@0 38
samer@0 39 public VarianceICA(Node node, int inputs, int outputs)
samer@0 40 {
samer@0 41 super(node);
samer@0 42 Shell.push(node);
samer@0 43
samer@0 44 n = inputs;
samer@0 45 m = outputs;
samer@0 46
samer@0 47 s = new VVector("s",m);
samer@0 48 z = new VVector("z",n);
samer@0 49 e = new VVector("e",n);
samer@0 50 E = new VDouble("E");
samer@0 51 A = new Matrix("A",n,m);
samer@0 52 A.identity();
samer@0 53
samer@0 54 e_ = e.array();
samer@0 55 z_ = z.array();
samer@0 56 s_ = s.array();
samer@0 57 tA = new MatrixTimesVector(A);
samer@0 58 tAt= new MatrixTransposeTimesVector(A);
samer@0 59
samer@0 60 Shell.pop();
samer@0 61 }
samer@0 62
samer@0 63 public Model getSourceModel() { return Ms; }
samer@0 64 public void setSourceModel(Model m) { Ms=m; Ms.setInput(s); }
samer@0 65 public void setInput(Vec in) { x=in; x_ = x.array(); }
samer@0 66 public Matrix basisMatrix() { return A; }
samer@0 67 public VVector output() { return s; }
samer@0 68 public VVector error() { return e; }
samer@0 69
samer@0 70 public int getSize() { return n; }
samer@0 71
samer@0 72 public void setInferenceTask(Task t) { inference=t; }
samer@0 73
samer@0 74 public void infer() {
samer@0 75 try { inference.run(); }
samer@0 76 catch (Exception ex) { throw new Error("inference failed"); }
samer@0 77 tA.apply(s_,z_);
samer@0 78 Mathx.sub(e_,x_,z_);
samer@0 79 Mathx.div(e_,z_);
samer@0 80 Mathx.div(e_,z_);
samer@0 81 e.changed(); z.changed(); s.changed();
samer@0 82 }
samer@0 83
samer@0 84 public void compute() {
samer@0 85 double EE=0;
samer@0 86 for (int i=0; i<n; i++) {
samer@0 87 double t=x_[i]/z_[i];
samer@0 88 EE += t - Math.log(t);
samer@0 89 }
samer@0 90 E.set(EE - n + Ms.getEnergy());
samer@0 91 // what about dE/dx?
samer@0 92 }
samer@0 93 public double getEnergy() { return E.value; }
samer@0 94 public double [] getGradient() { return null; } // this is wrong
samer@0 95
samer@0 96 public void run() { infer(); }
samer@0 97 public void dispose()
samer@0 98 {
samer@0 99 s.dispose();
samer@0 100 A.dispose();
samer@0 101 e.dispose();
samer@0 102 z.dispose();
samer@0 103 E.dispose();
samer@0 104 tA.dispose();
samer@0 105 tAt.dispose();
samer@0 106
samer@0 107 super.dispose();
samer@0 108 }
samer@0 109
samer@0 110 public Functionx functionx() { return null; }
samer@0 111
samer@0 112 public Functionx posterior() {
samer@0 113 // returns Functionx which evaluates E and dE/ds at current x
samer@0 114 return new Functionx() {
samer@0 115 Functionx fMs=Ms.functionx();
samer@0 116 double [] ge=new double[n];
samer@0 117
samer@0 118 public void dispose() { fMs.dispose(); }
samer@0 119 public void evaluate( Datum P) { P.f=evaluate(P.x,P.g); }
samer@0 120 public double evaluate( double [] s, double [] g)
samer@0 121 {
samer@0 122 double E=0;
samer@0 123
samer@0 124 tA.apply(s,z_); // z=As
samer@0 125 for (int i=0; i<n; i++) {
samer@0 126 double t=x_[i]/z_[i];
samer@0 127 g[i] = (t-1)/z_[i];
samer@0 128 E += t - Math.log(t);
samer@0 129 }
samer@0 130 tAt.apply(g,ge); // ge=A'*g
samer@0 131 E+=fMs.evaluate(s,g);
samer@0 132 Mathx.sub(g,ge);
samer@0 133 return E;
samer@0 134 }
samer@0 135 };
samer@0 136 }
samer@0 137
samer@0 138 public Trainer learnHebbian() {
samer@0 139 Shell.push(node);
samer@0 140 try { return new MatrixTrainer(e_,A,s_); }
samer@0 141 finally { Shell.pop(); }
samer@0 142 }
samer@0 143
samer@0 144 public Trainer learnLewickiSejnowski()
samer@0 145 {
samer@0 146 Shell.push(node);
samer@0 147 try {
samer@0 148 final double [] h =new double[n];
samer@0 149 final double [] f =new double[m];
samer@0 150 return new MatrixTrainer(h,A,s_) {
samer@0 151
samer@0 152 public void accumulate(double w) {
samer@0 153 tA.apply(Ms.getGradient(),h);
samer@0 154 super.accumulate(w);
samer@0 155 }
samer@0 156 public void flush() { // flush with decay
samer@0 157 for (int j=0; j<this.m; j++)
samer@0 158 for (int i=0; i<this.n; i++)
samer@0 159 _T[i][j] -= count*_A[i][j];
samer@0 160 super.flush();
samer@0 161 }
samer@0 162 };
samer@0 163 } finally { Shell.pop(); }
samer@0 164 }
samer@0 165
samer@0 166 public Trainer learnDecayWhenActive()
samer@0 167 {
samer@0 168 Shell.push(node);
samer@0 169 try {
samer@0 170 final double [] h =new double[n];
samer@0 171 return new MatrixTrainer(h,A,s_) {
samer@0 172 VDouble th=new VDouble("threshold",0.01);
samer@0 173 public void accumulate(double w) {
samer@0 174 tA.apply(Ms.getGradient(),h);
samer@0 175 super.accumulate(w);
samer@0 176
samer@0 177 // decay when active part
samer@0 178 double thresh=th.value;
samer@0 179 for (int j=0; j<this.m; j++)
samer@0 180 if (isActive(s_[j],thresh))
samer@0 181 for (int i=0; i<this.n; i++)
samer@0 182 _T[i][j] -= _A[i][j];
samer@0 183 }
samer@0 184 };
samer@0 185 } finally { Shell.pop(); }
samer@0 186 }
samer@0 187 static boolean isActive(double s, double t) { return s>t || s<-t; }
samer@0 188 }