Mercurial > hg > jslab
view src/samer/models/AlignedGaussian.java @ 5:b67a33c44de7
Remove some crap, etc
author | samer |
---|---|
date | Fri, 05 Apr 2019 21:34:25 +0100 |
parents | bf79fb79ee13 |
children |
line wrap: on
line source
/* * Copyright (c) 2002, Samer Abdallah, King's College London. * All rights reserved. * * This software is provided AS iS and WITHOUT ANY WARRANTY; * without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. */ package samer.models; import samer.core.*; import samer.core.types.*; import samer.maths.*; import samer.maths.opt.*; import samer.tools.*; /** Differential scaler: scales and offsets each element of a vector independently aiming to match a given prior model. This is like ICA using a diagonal weight matrix, with a 'zero-mean' normalisation built in, though it doesn't actually the mean to do this, but some statistic based on the prior model. */ public class AlignedGaussian extends AnonymousTask implements Model { private int n; private Vec x; private VVector s; private VVector w; // vector of multipliers private VDouble E; private VDouble logA; double [] _x, _s, _g, _w, _m, phi; public AlignedGaussian( Vec input) { this(input.size()); setInput(input); } public AlignedGaussian( int N) { n = N; x = null; w = new VVector("w",n); w.addSaver(); s = new VVector("s",n); E = new VDouble("E"); logA = new VDouble("log|A|",0); _s = s.array(); _w = w.array(); _g = new double[n]; phi = null; reset(); } public int getSize() { return n; } public VVector output() { return s; } public VVector weights() { return w; } public void setInput(Vec in) { x=in; _x=x.array(); } public void reset() { reset(1.0); } public void reset(double k) { Mathx.setAll(_w,k); w.changed(); logA.set(-sumlog(_w)); } public String toString() { return "AlignedGaussian("+x+")"; } private static double sumlog(double [] x) { double S=0; for (int i=0; i<x.length; i++) S += Math.log(x[i]); return S; } public void dispose() { logA.dispose(); w.dispose(); s.dispose(); super.dispose(); } public void infer() { Mathx.mul(_s,_x,_w); s.changed(); } public void compute() { Mathx.mul(_g,_s,_w); E.set(0.5*Mathx.norm2(_s)+logA.value); } public double getEnergy() { return E.value; } public double [] getGradient() { return _g; } public Functionx functionx() { return new Functionx() { double [] s=new double[n]; public void dispose() {} public void evaluate(Datum P) { P.f=evaluate(P.x,P.g); } public double evaluate(double [] x, double [] g) { Mathx.mul(s,x,_w); Mathx.mul(g,s,_w); return 0.5*Mathx.norm2(s); // +logA.value; } }; } public void starting() { logA.set(-sumlog(_w)); } public void stopping() {} public void run() { infer(); compute(); } }