comparison src/samer/models/AlignedGaussian.java @ 0:bf79fb79ee13

Initial Mercurial check in.
author samer
date Tue, 17 Jan 2012 17:50:20 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:bf79fb79ee13
1 /*
2 * Copyright (c) 2002, Samer Abdallah, King's College London.
3 * All rights reserved.
4 *
5 * This software is provided AS iS and WITHOUT ANY WARRANTY;
6 * without even the implied warranty of MERCHANTABILITY or
7 * FITNESS FOR A PARTICULAR PURPOSE.
8 */
9
10 package samer.models;
11
12 import samer.core.*;
13 import samer.core.types.*;
14 import samer.maths.*;
15 import samer.maths.opt.*;
16 import samer.tools.*;
17
18 /**
19 Differential scaler: scales and offsets each element of a vector
20 independently aiming to match a given prior model. This is
21 like ICA using a diagonal weight matrix, with a 'zero-mean'
22 normalisation built in, though it doesn't actually use the mean to
23 do this, but some statistic based on the prior model.
24 */
25
26 public class AlignedGaussian extends AnonymousTask implements Model
27 {
28 private int n;
29 private Vec x;
30 private VVector s;
31 private VVector w; // vector of multipliers
32 private VDouble E;
33 private VDouble logA;
34
35 double [] _x, _s, _g, _w, _m, phi;
36
37 public AlignedGaussian( Vec input) { this(input.size()); setInput(input); }
38 public AlignedGaussian( int N)
39 {
40 n = N;
41
42 x = null;
43 w = new VVector("w",n); w.addSaver();
44 s = new VVector("s",n);
45 E = new VDouble("E");
46 logA = new VDouble("log|A|",0);
47
48
49 _s = s.array();
50 _w = w.array();
51 _g = new double[n];
52 phi = null;
53 reset();
54 }
55
56 public int getSize() { return n; }
57 public VVector output() { return s; }
58 public VVector weights() { return w; }
59 public void setInput(Vec in) { x=in; _x=x.array(); }
60 public void reset() { reset(1.0); }
61 public void reset(double k) {
62 Mathx.setAll(_w,k); w.changed();
63 logA.set(-sumlog(_w));
64 }
65
66 public String toString() { return "AlignedGaussian("+x+")"; }
67 private static double sumlog(double [] x) {
68 double S=0;
69 for (int i=0; i<x.length; i++) S += Math.log(x[i]);
70 return S;
71 }
72
73 public void dispose()
74 {
75 logA.dispose();
76 w.dispose();
77 s.dispose();
78 super.dispose();
79 }
80
81 public void infer() { Mathx.mul(_s,_x,_w); s.changed(); }
82 public void compute() {
83 Mathx.mul(_g,_s,_w);
84 E.set(0.5*Mathx.norm2(_s)+logA.value);
85 }
86
87 public double getEnergy() { return E.value; }
88 public double [] getGradient() { return _g; }
89
90 public Functionx functionx() {
91 return new Functionx() {
92 double [] s=new double[n];
93
94 public void dispose() {}
95 public void evaluate(Datum P) { P.f=evaluate(P.x,P.g); }
96 public double evaluate(double [] x, double [] g) {
97 Mathx.mul(s,x,_w);
98 Mathx.mul(g,s,_w);
99 return 0.5*Mathx.norm2(s); // +logA.value;
100 }
101 };
102 }
103
104 public void starting() { logA.set(-sumlog(_w)); }
105 public void stopping() {}
106 public void run() { infer(); compute(); }
107 }
108