samer@0
|
1 /*
|
samer@0
|
2 * Copyright (c) 2002, Samer Abdallah, King's College London.
|
samer@0
|
3 * All rights reserved.
|
samer@0
|
4 *
|
samer@0
|
5 * This software is provided AS IS and WITHOUT ANY WARRANTY;
|
samer@0
|
6 * without even the implied warranty of MERCHANTABILITY or
|
samer@0
|
7 * FITNESS FOR A PARTICULAR PURPOSE.
|
samer@0
|
8 */
|
samer@0
|
9
|
samer@0
|
10 package samer.models;
|
samer@0
|
11
|
samer@0
|
12 import samer.core.*;
|
samer@0
|
13 import samer.core.types.*;
|
samer@0
|
14 import samer.maths.*;
|
samer@0
|
15 import samer.maths.opt.*;
|
samer@0
|
16 import samer.tools.*;
|
samer@0
|
17
|
samer@0
|
18
|
samer@0
|
19 public class Mixture extends NamedTask implements Model
|
samer@0
|
20 {
|
samer@0
|
21 private Model M[]; // models
|
samer@0
|
22 private int n, m; // size of vector, num models
|
samer@0
|
23 private Vec x; // input
|
samer@0
|
24 private VVector w; // prior weights
|
samer@0
|
25 private VVector s; // posterior
|
samer@0
|
26 private int k; // MAP estimate
|
samer@0
|
27 private VDouble Z; // Parition function, ie p(x)
|
samer@0
|
28 private double[] _x,_s,_w,_g;
|
samer@0
|
29
|
samer@0
|
30 public Mixture( Vec input, int m) { this(input.size(), m); setInput(input); }
|
samer@0
|
31 public Mixture( int N, int L)
|
samer@0
|
32 {
|
samer@0
|
33 super("mixture");
|
samer@0
|
34 Shell.push(node);
|
samer@0
|
35
|
samer@0
|
36 n = N;
|
samer@0
|
37 m = L;
|
samer@0
|
38
|
samer@0
|
39 x = null;
|
samer@0
|
40 w = new VVector("prior",m);
|
samer@0
|
41 s = new VVector("posterior",m);
|
samer@0
|
42 Z = new VDouble("Z");
|
samer@0
|
43 M = new Model[m];
|
samer@0
|
44 Shell.pop();
|
samer@0
|
45
|
samer@0
|
46 _s=s.array();
|
samer@0
|
47 _w=w.array();
|
samer@0
|
48 _g=new double[n];
|
samer@0
|
49 Mathx.set(_w,new Constant(1.0/L));
|
samer@0
|
50 }
|
samer@0
|
51
|
samer@0
|
52 public VVector prior() { return w; }
|
samer@0
|
53 public VVector posterior() { return s; }
|
samer@0
|
54 public void setModel(int i, Model m) { M[i]=m; }
|
samer@0
|
55 public void setInput(Vec in) { x=in; _x=x.array(); }
|
samer@0
|
56 public int getSize() { return n; }
|
samer@0
|
57
|
samer@0
|
58 public void dispose()
|
samer@0
|
59 {
|
samer@0
|
60 s.dispose();
|
samer@0
|
61 w.dispose();
|
samer@0
|
62 Z.dispose();
|
samer@0
|
63 for (int i=0; i<m; i++) M[i].dispose();
|
samer@0
|
64 super.dispose();
|
samer@0
|
65 }
|
samer@0
|
66
|
samer@0
|
67 public void infer() {
|
samer@0
|
68 // get models to compute energies.
|
samer@0
|
69 // for (int i=0; i<m; i++) { M[i].infer(); M[i].compute(); }
|
samer@0
|
70
|
samer@0
|
71 // compute relative posterior
|
samer@0
|
72 for (int i=0; i<m; i++) _s[i] = M[i].getEnergy(); // collect energies
|
samer@0
|
73 double Emin=Mathx.min(_s);
|
samer@0
|
74 for (int i=0; i<m; i++) _s[i] = _w[i]*Math.exp(Emin-_s[i]);
|
samer@0
|
75
|
samer@0
|
76 // compute partition function, normalise posterior
|
samer@0
|
77 Z.set(Mathx.sum(_s)); // compute parition function
|
samer@0
|
78 k=Mathx.argmax(_s); // get MAP model
|
samer@0
|
79 Mathx.mul(_s,1/Z.value); // normalise posterior
|
samer@0
|
80 s.changed();
|
samer@0
|
81 }
|
samer@0
|
82
|
samer@0
|
83 public void compute()
|
samer@0
|
84 {
|
samer@0
|
85 /* compute gradients weighted by posterior */
|
samer@0
|
86 Mathx.zero(_g);
|
samer@0
|
87 for (int i=0; i<m; i++) {
|
samer@0
|
88 double [] phi = M[i].getGradient();
|
samer@0
|
89 for (int j=0; j<n; j++) _g[j] += _s[i]*phi[j];
|
samer@0
|
90 }
|
samer@0
|
91 }
|
samer@0
|
92
|
samer@0
|
93 public double getEnergy() { return -Math.log(Z.value); }
|
samer@0
|
94 public double [] getGradient() { return _g; }
|
samer@0
|
95
|
samer@0
|
96 public Functionx functionx() { return null; }
|
samer@0
|
97
|
samer@0
|
98 public void run() { infer(); }
|
samer@0
|
99
|
samer@0
|
100 public Trainer getTrainer() { return new Trainer(); }
|
samer@0
|
101
|
samer@0
|
102 public class Trainer implements Model.Trainer
|
samer@0
|
103 {
|
samer@0
|
104 Model.Trainer T[];
|
samer@0
|
105 VDouble rate;
|
samer@0
|
106 VVector dw;
|
samer@0
|
107 double batch, _dw[];
|
samer@0
|
108
|
samer@0
|
109 public Trainer() {
|
samer@0
|
110 T=new Model.Trainer[m]; // should all be null
|
samer@0
|
111 rate=new VDouble("rate",0.001);
|
samer@0
|
112 dw=new VVector("dw",m);
|
samer@0
|
113 _dw=dw.array();
|
samer@0
|
114 }
|
samer@0
|
115
|
samer@0
|
116 public void setTrainer(int i,Model.Trainer t) { T[i]=t; }
|
samer@0
|
117 public void dispose() { rate.dispose(); dw.dispose(); }
|
samer@0
|
118
|
samer@0
|
119 public void accumulate() { accumulate(1.0); }
|
samer@0
|
120 public void accumulate(double w) {
|
samer@0
|
121 batch+=w;
|
samer@0
|
122 for (int i=0;i<m; i++) {
|
samer@0
|
123 if (T[i]!=null) T[i].accumulate(w*_s[i]); // sweet
|
samer@0
|
124 }
|
samer@0
|
125
|
samer@0
|
126 // now accumulate info about priors
|
samer@0
|
127 Mathx.add(_dw,_s);
|
samer@0
|
128 }
|
samer@0
|
129
|
samer@0
|
130 public void oneshot() { accumulate(1.0); flush(); }
|
samer@0
|
131 public void flush() {
|
samer@0
|
132 for (int i=0; i<m; i++) if (T[i]!=null) T[i].flush();
|
samer@0
|
133 double lambda=Mathx.dot(_w,_dw)/Mathx.dot(_w,_w);
|
samer@0
|
134 double nu=rate.value/batch;
|
samer@0
|
135
|
samer@0
|
136 dw.changed();
|
samer@0
|
137 for (int i=0; i<m; i++) {
|
samer@0
|
138 _w[i] *= Math.exp(nu*(_dw[i]-lambda*_w[i])); // update w
|
samer@0
|
139 }
|
samer@0
|
140 Mathx.zero(_dw); batch=0;
|
samer@0
|
141
|
samer@0
|
142 // normalise
|
samer@0
|
143 Mathx.mul(_w,1/Mathx.sum(_w));
|
samer@0
|
144 w.changed();
|
samer@0
|
145 }
|
samer@0
|
146 public void reset() {
|
samer@0
|
147 for (int i=0; i<m; i++) if (T[i]!=null) T[i].reset();
|
samer@0
|
148 Mathx.zero(_dw);
|
samer@0
|
149 batch=0;
|
samer@0
|
150 }
|
samer@0
|
151 }
|
samer@0
|
152 }
|
samer@0
|
153
|