Mercurial > hg > jslab
view src/samer/models/NoisyICA.java @ 8:5e3cbbf173aa tip
Reorganise some more
author | samer |
---|---|
date | Fri, 05 Apr 2019 22:41:58 +0100 |
parents | bf79fb79ee13 |
children |
line wrap: on
line source
/* * Copyright (c) 2000, Samer Abdallah, King's College London. * All rights reserved. * * This software is provided AS iS and WITHOUT ANY WARRANTY; * without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. */ package samer.models; import samer.core.*; import samer.core.types.*; import samer.tools.*; import samer.maths.*; import samer.maths.*; import samer.maths.opt.*; public class NoisyICA extends NamedTask implements Model { Vec x; // data (input) int n, m; // sizes (data, sources) Matrix A; // basis matrix Model Ms, Me; // source and noise models VVector s, z, e; // sources, reconstruction, error VDouble E; // energy Task inference=new NullTask(); // ----- variables used in computations ----------------- private double [] e_; private double [] x_; private double [] z_; private double [] s_; private VectorFunctionOfVector tA, tAt; public NoisyICA(Vec in,int outs) { this(new Node("noisyica"),in.size(),outs); setInput(in); } public NoisyICA(int ins,int outs) { this(new Node("noisyica"),ins,outs); } public NoisyICA(Node node, int inputs, int outputs) { super(node); Shell.push(node); n = inputs; m = outputs; s = new VVector("s",m); z = new VVector("z",n); e = new VVector("e",n); E = new VDouble("E"); A = new Matrix("A",n,m); A.identity(); e_ = e.array(); z_ = z.array(); s_ = s.array(); tA = new MatrixTimesVector(A); tAt= new MatrixTransposeTimesVector(A); Shell.pop(); } public Model getSourceModel() { return Ms; } public Model getNoiseModel() { return Me; } public void setSourceModel(Model m) { Ms=m; Ms.setInput(s); } public void setNoiseModel(Model m) { Me=m; Me.setInput(e); } public void setInput(Vec in) { x=in; x_ = x.array(); } public Matrix basisMatrix() { return A; } public VVector output() { return s; } public VVector error() { return e; } public VVector reconstruction() { return z; } public int getSize() { return n; } public void setInferenceTask(Task t) { inference=t; } public void infer() { try { inference.run(); } catch (Exception ex) { Shell.trace("error: "+ex); ex.printStackTrace(); throw new Error("inference failed: "+ex); } tA.apply(s_,z_); Mathx.sub(e_,x_,z_); e.changed(); z.changed(); s.changed(); } public void compute() { E.set(Me.getEnergy() + Ms.getEnergy()); // what about dE/dx? } public double getEnergy() { return E.value; } public double [] getGradient() { return null; } // this is wrong /** get basis vector norms into given array */ public void norms(double [] na) { double [][] M=A.getArray(); Mathx.zero(na); for (int i=0; i<n; i++) { double [] Mi=M[i]; for (int j=0; j<m; j++) { na[j] += Mi[j]*Mi[j]; } } } public void run() { infer(); } public void dispose() { s.dispose(); A.dispose(); e.dispose(); z.dispose(); E.dispose(); tA.dispose(); tAt.dispose(); super.dispose(); } public Functionx functionx() { return null; } public Functionx posterior() { // returns Functionx which evaluates E and dE/ds at current x return new Functionx() { Functionx fMs=Ms.functionx(); Functionx fMe=Me.functionx(); double [] e=new double[n]; double [] ge=new double[n]; double [] gs=new double[m]; public void dispose() { fMs.dispose(); fMe.dispose(); } public void evaluate( Datum P) { P.f=evaluate(P.x,P.g); } public double evaluate( double [] s, double [] g) { tA.apply(s,z_); Mathx.sub(e,x_,z_); double Ee=fMe.evaluate(e,ge); tAt.apply(ge,gs); // gs=A'*gamma(e) double Es=fMs.evaluate(s,g); Mathx.sub(g,gs); // g=gamma(s)-ge return Es+Ee; } }; } public Trainer learnHebbian() { return new MatrixTrainer(e_,A,s_); } public Trainer learnLewickiSejnowski() { final double [] h =new double[n]; final double [] f =new double[m]; return new MatrixTrainer(h,A,s_) { public void accumulate(double w) { // tAt.apply(Me.getGradient(),f); tA.apply(f,h); tA.apply(Ms.getGradient(),h); super.accumulate(w); } public void flush() { // flush with decay for (int j=0; j<this.m; j++) for (int i=0; i<this.n; i++) _T[i][j] -= count*_A[i][j]; super.flush(); } }; } public Trainer learnDecayWhenActive() { final double [] h =new double[n]; return new MatrixTrainer(h,A,s_) { VDouble th=new VDouble("threshold",0.01); public void accumulate(double w) { // perhaps would like to get info out of optimiser here tA.apply(Ms.getGradient(),h); super.accumulate(w); // decay when active part double thresh=th.value; for (int j=0; j<this.m; j++) if (isActive(s_[j],thresh)) for (int i=0; i<this.n; i++) _T[i][j] -= _A[i][j]; } public VDouble getThreshold() { return th; } }; } static boolean isActive(double s, double t) { return s>t || s<-t; } }