Mercurial > hg > jslab
diff src/samer/models/NoisyICA.java @ 0:bf79fb79ee13
Initial Mercurial check in.
author | samer |
---|---|
date | Tue, 17 Jan 2012 17:50:20 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/samer/models/NoisyICA.java Tue Jan 17 17:50:20 2012 +0000 @@ -0,0 +1,191 @@ +/* + * Copyright (c) 2000, Samer Abdallah, King's College London. + * All rights reserved. + * + * This software is provided AS iS and WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. + */ + +package samer.models; +import samer.core.*; +import samer.core.types.*; +import samer.tools.*; +import samer.maths.*; +import samer.maths.*; +import samer.maths.opt.*; + +public class NoisyICA extends NamedTask implements Model +{ + Vec x; // data (input) + int n, m; // sizes (data, sources) + Matrix A; // basis matrix + Model Ms, Me; // source and noise models + VVector s, z, e; // sources, reconstruction, error + VDouble E; // energy + + Task inference=new NullTask(); + + + // ----- variables used in computations ----------------- + + private double [] e_; + private double [] x_; + private double [] z_; + private double [] s_; + private VectorFunctionOfVector tA, tAt; + + + public NoisyICA(Vec in,int outs) { this(new Node("noisyica"),in.size(),outs); setInput(in); } + public NoisyICA(int ins,int outs) { this(new Node("noisyica"),ins,outs); } + public NoisyICA(Node node, int inputs, int outputs) + { + super(node); + Shell.push(node); + + n = inputs; + m = outputs; + + s = new VVector("s",m); + z = new VVector("z",n); + e = new VVector("e",n); + E = new VDouble("E"); + A = new Matrix("A",n,m); + A.identity(); + + e_ = e.array(); + z_ = z.array(); + s_ = s.array(); + tA = new MatrixTimesVector(A); + tAt= new MatrixTransposeTimesVector(A); + + Shell.pop(); + } + + public Model getSourceModel() { return Ms; } + public Model getNoiseModel() { return Me; } + public void setSourceModel(Model m) { Ms=m; Ms.setInput(s); } + public void setNoiseModel(Model m) { Me=m; Me.setInput(e); } + public void setInput(Vec in) { x=in; x_ = x.array(); } + public Matrix basisMatrix() { return A; } + public VVector output() { return s; } + public VVector error() { return e; } + public VVector reconstruction() { return z; } + + public int getSize() { return n; } + + public void setInferenceTask(Task t) { inference=t; } + + public void infer() { + try { inference.run(); } + catch (Exception ex) { + Shell.trace("error: "+ex); + ex.printStackTrace(); + throw new Error("inference failed: "+ex); } + tA.apply(s_,z_); Mathx.sub(e_,x_,z_); + e.changed(); z.changed(); s.changed(); + } + + public void compute() { + E.set(Me.getEnergy() + Ms.getEnergy()); + // what about dE/dx? + } + public double getEnergy() { return E.value; } + public double [] getGradient() { return null; } // this is wrong + + /** get basis vector norms into given array */ + public void norms(double [] na) + { + double [][] M=A.getArray(); + + Mathx.zero(na); + for (int i=0; i<n; i++) { + double [] Mi=M[i]; + for (int j=0; j<m; j++) { + na[j] += Mi[j]*Mi[j]; + } + } + } + + public void run() { infer(); } + public void dispose() + { + s.dispose(); + A.dispose(); + e.dispose(); + z.dispose(); + E.dispose(); + tA.dispose(); + tAt.dispose(); + + super.dispose(); + } + + public Functionx functionx() { return null; } + + public Functionx posterior() { + // returns Functionx which evaluates E and dE/ds at current x + return new Functionx() { + Functionx fMs=Ms.functionx(); + Functionx fMe=Me.functionx(); + double [] e=new double[n]; + double [] ge=new double[n]; + double [] gs=new double[m]; + + public void dispose() { fMs.dispose(); fMe.dispose(); } + public void evaluate( Datum P) { P.f=evaluate(P.x,P.g); } + public double evaluate( double [] s, double [] g) + { + tA.apply(s,z_); Mathx.sub(e,x_,z_); + double Ee=fMe.evaluate(e,ge); tAt.apply(ge,gs); // gs=A'*gamma(e) + double Es=fMs.evaluate(s,g); Mathx.sub(g,gs); // g=gamma(s)-ge + return Es+Ee; + } + }; + } + + public Trainer learnHebbian() { return new MatrixTrainer(e_,A,s_); } + + public Trainer learnLewickiSejnowski() + { + final double [] h =new double[n]; + final double [] f =new double[m]; + return new MatrixTrainer(h,A,s_) { + + public void accumulate(double w) { + // tAt.apply(Me.getGradient(),f); tA.apply(f,h); + tA.apply(Ms.getGradient(),h); + super.accumulate(w); + } + public void flush() { // flush with decay + for (int j=0; j<this.m; j++) + for (int i=0; i<this.n; i++) + _T[i][j] -= count*_A[i][j]; + super.flush(); + } + }; + } + + public Trainer learnDecayWhenActive() + { + final double [] h =new double[n]; + return new MatrixTrainer(h,A,s_) { + VDouble th=new VDouble("threshold",0.01); + public void accumulate(double w) { + // perhaps would like to get info out of optimiser here + tA.apply(Ms.getGradient(),h); + super.accumulate(w); + + // decay when active part + double thresh=th.value; + for (int j=0; j<this.m; j++) + if (isActive(s_[j],thresh)) + for (int i=0; i<this.n; i++) + _T[i][j] -= _A[i][j]; + } + + public VDouble getThreshold() { return th; } + }; + } + static boolean isActive(double s, double t) { return s>t || s<-t; } +}