diff src/samer/models/AlignedGaussian.java @ 0:bf79fb79ee13

Initial Mercurial check in.
author samer
date Tue, 17 Jan 2012 17:50:20 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/samer/models/AlignedGaussian.java	Tue Jan 17 17:50:20 2012 +0000
@@ -0,0 +1,108 @@
+/*
+ *	Copyright (c) 2002, Samer Abdallah, King's College London.
+ *	All rights reserved.
+ *
+ *	This software is provided AS iS and WITHOUT ANY WARRANTY;
+ *	without even the implied warranty of MERCHANTABILITY or
+ *	FITNESS FOR A PARTICULAR PURPOSE.
+ */
+
+package samer.models;
+
+import samer.core.*;
+import samer.core.types.*;
+import samer.maths.*;
+import samer.maths.opt.*;
+import samer.tools.*;
+
+/**
+	Differential scaler: scales and offsets each element of a vector
+	independently aiming to match a given prior model. This is
+	like ICA using a diagonal weight matrix, with a 'zero-mean'
+	normalisation built in, though it doesn't actually the mean to
+	do this, but some statistic based on the prior model.
+*/
+
+public class AlignedGaussian extends AnonymousTask implements Model
+{
+	private int			n;
+	private Vec			x;
+	private VVector	s;
+	private VVector	w;		// vector of multipliers
+	private VDouble	E;
+	private VDouble	logA;
+
+	double []		_x, _s, _g, _w, _m, phi;
+
+	public AlignedGaussian( Vec input) { this(input.size()); setInput(input); }
+	public AlignedGaussian( int N)
+	{
+		n = N;
+
+		x = null;
+		w = new VVector("w",n);		w.addSaver();
+		s = new VVector("s",n);
+		E = new VDouble("E");
+		logA = new VDouble("log|A|",0);
+		
+
+		_s = s.array();
+		_w = w.array();
+		_g = new double[n];
+		phi = null;
+		reset();
+	}
+
+	public int getSize() { return n; }
+	public VVector output() { return s; }
+	public VVector weights() { return w; }
+	public void setInput(Vec in) { x=in; _x=x.array(); }
+	public void reset() { reset(1.0); }
+	public void reset(double k) {
+		Mathx.setAll(_w,k); w.changed();
+		logA.set(-sumlog(_w));
+	}
+
+	public String toString() { return "AlignedGaussian("+x+")"; }
+	private static double sumlog(double [] x) {
+		double S=0;
+		for (int i=0; i<x.length; i++) S += Math.log(x[i]);
+		return S;
+	}
+
+	public void dispose()
+	{
+		logA.dispose();
+		w.dispose();
+		s.dispose();
+		super.dispose();
+	}
+
+	public void infer() { Mathx.mul(_s,_x,_w); s.changed(); }
+	public void compute() { 
+		Mathx.mul(_g,_s,_w); 
+		E.set(0.5*Mathx.norm2(_s)+logA.value);
+	}
+
+	public double	getEnergy() { return E.value; }
+	public double [] getGradient() { return _g; }
+
+	public Functionx functionx() {
+		return new Functionx() {
+			double [] s=new double[n];
+
+			public void dispose() {}
+			public void evaluate(Datum P) { P.f=evaluate(P.x,P.g); }
+			public double evaluate(double [] x, double [] g) {
+				Mathx.mul(s,x,_w);
+				Mathx.mul(g,s,_w);
+				return 0.5*Mathx.norm2(s); // +logA.value;
+			}
+		};
+	}
+
+	public void starting() { logA.set(-sumlog(_w)); }
+	public void stopping() {}
+	public void run() { infer(); compute(); }
+}
+