Mercurial > hg > jslab
view src/samer/models/VarianceICA.java @ 0:bf79fb79ee13
Initial Mercurial check in.
author | samer |
---|---|
date | Tue, 17 Jan 2012 17:50:20 +0000 |
parents | |
children |
line wrap: on
line source
/* * Copyright (c) 2000, Samer Abdallah, King's College London. * All rights reserved. * * This software is provided AS iS and WITHOUT ANY WARRANTY; * without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. */ package samer.models; import samer.core.*; import samer.core.types.*; import samer.tools.*; import samer.maths.*; import samer.maths.*; import samer.maths.opt.*; public class VarianceICA extends NamedTask implements Model { Vec x; // data (input) int n, m; // sizes (data, sources) Matrix A; // basis matrix Model Ms; // source model VVector s, z, e; // sources, reconstruction, error VDouble E; // energy Task inference=new NullTask(); // ----- variables used in computations ----------------- private double [] e_; private double [] x_; private double [] z_; private double [] s_; private VectorFunctionOfVector tA, tAt; public VarianceICA(Node node, int inputs, int outputs) { super(node); Shell.push(node); n = inputs; m = outputs; s = new VVector("s",m); z = new VVector("z",n); e = new VVector("e",n); E = new VDouble("E"); A = new Matrix("A",n,m); A.identity(); e_ = e.array(); z_ = z.array(); s_ = s.array(); tA = new MatrixTimesVector(A); tAt= new MatrixTransposeTimesVector(A); Shell.pop(); } public Model getSourceModel() { return Ms; } public void setSourceModel(Model m) { Ms=m; Ms.setInput(s); } public void setInput(Vec in) { x=in; x_ = x.array(); } public Matrix basisMatrix() { return A; } public VVector output() { return s; } public VVector error() { return e; } public int getSize() { return n; } public void setInferenceTask(Task t) { inference=t; } public void infer() { try { inference.run(); } catch (Exception ex) { throw new Error("inference failed"); } tA.apply(s_,z_); Mathx.sub(e_,x_,z_); Mathx.div(e_,z_); Mathx.div(e_,z_); e.changed(); z.changed(); s.changed(); } public void compute() { double EE=0; for (int i=0; i<n; i++) { double t=x_[i]/z_[i]; EE += t - Math.log(t); } E.set(EE - n + Ms.getEnergy()); // what about dE/dx? } public double getEnergy() { return E.value; } public double [] getGradient() { return null; } // this is wrong public void run() { infer(); } public void dispose() { s.dispose(); A.dispose(); e.dispose(); z.dispose(); E.dispose(); tA.dispose(); tAt.dispose(); super.dispose(); } public Functionx functionx() { return null; } public Functionx posterior() { // returns Functionx which evaluates E and dE/ds at current x return new Functionx() { Functionx fMs=Ms.functionx(); double [] ge=new double[n]; public void dispose() { fMs.dispose(); } public void evaluate( Datum P) { P.f=evaluate(P.x,P.g); } public double evaluate( double [] s, double [] g) { double E=0; tA.apply(s,z_); // z=As for (int i=0; i<n; i++) { double t=x_[i]/z_[i]; g[i] = (t-1)/z_[i]; E += t - Math.log(t); } tAt.apply(g,ge); // ge=A'*g E+=fMs.evaluate(s,g); Mathx.sub(g,ge); return E; } }; } public Trainer learnHebbian() { Shell.push(node); try { return new MatrixTrainer(e_,A,s_); } finally { Shell.pop(); } } public Trainer learnLewickiSejnowski() { Shell.push(node); try { final double [] h =new double[n]; final double [] f =new double[m]; return new MatrixTrainer(h,A,s_) { public void accumulate(double w) { tA.apply(Ms.getGradient(),h); super.accumulate(w); } public void flush() { // flush with decay for (int j=0; j<this.m; j++) for (int i=0; i<this.n; i++) _T[i][j] -= count*_A[i][j]; super.flush(); } }; } finally { Shell.pop(); } } public Trainer learnDecayWhenActive() { Shell.push(node); try { final double [] h =new double[n]; return new MatrixTrainer(h,A,s_) { VDouble th=new VDouble("threshold",0.01); public void accumulate(double w) { tA.apply(Ms.getGradient(),h); super.accumulate(w); // decay when active part double thresh=th.value; for (int j=0; j<this.m; j++) if (isActive(s_[j],thresh)) for (int i=0; i<this.n; i++) _T[i][j] -= _A[i][j]; } }; } finally { Shell.pop(); } } static boolean isActive(double s, double t) { return s>t || s<-t; } }