Mercurial > hg > jslab
comparison src/samer/models/Mixture.java @ 0:bf79fb79ee13
Initial Mercurial check in.
author | samer |
---|---|
date | Tue, 17 Jan 2012 17:50:20 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:bf79fb79ee13 |
---|---|
1 /* | |
2 * Copyright (c) 2002, Samer Abdallah, King's College London. | |
3 * All rights reserved. | |
4 * | |
5 * This software is provided AS IS and WITHOUT ANY WARRANTY; | |
6 * without even the implied warranty of MERCHANTABILITY or | |
7 * FITNESS FOR A PARTICULAR PURPOSE. | |
8 */ | |
9 | |
10 package samer.models; | |
11 | |
12 import samer.core.*; | |
13 import samer.core.types.*; | |
14 import samer.maths.*; | |
15 import samer.maths.opt.*; | |
16 import samer.tools.*; | |
17 | |
18 | |
19 public class Mixture extends NamedTask implements Model | |
20 { | |
21 private Model M[]; // models | |
22 private int n, m; // size of vector, num models | |
23 private Vec x; // input | |
24 private VVector w; // prior weights | |
25 private VVector s; // posterior | |
26 private int k; // MAP estimate | |
27 private VDouble Z; // Parition function, ie p(x) | |
28 private double[] _x,_s,_w,_g; | |
29 | |
30 public Mixture( Vec input, int m) { this(input.size(), m); setInput(input); } | |
31 public Mixture( int N, int L) | |
32 { | |
33 super("mixture"); | |
34 Shell.push(node); | |
35 | |
36 n = N; | |
37 m = L; | |
38 | |
39 x = null; | |
40 w = new VVector("prior",m); | |
41 s = new VVector("posterior",m); | |
42 Z = new VDouble("Z"); | |
43 M = new Model[m]; | |
44 Shell.pop(); | |
45 | |
46 _s=s.array(); | |
47 _w=w.array(); | |
48 _g=new double[n]; | |
49 Mathx.set(_w,new Constant(1.0/L)); | |
50 } | |
51 | |
52 public VVector prior() { return w; } | |
53 public VVector posterior() { return s; } | |
54 public void setModel(int i, Model m) { M[i]=m; } | |
55 public void setInput(Vec in) { x=in; _x=x.array(); } | |
56 public int getSize() { return n; } | |
57 | |
58 public void dispose() | |
59 { | |
60 s.dispose(); | |
61 w.dispose(); | |
62 Z.dispose(); | |
63 for (int i=0; i<m; i++) M[i].dispose(); | |
64 super.dispose(); | |
65 } | |
66 | |
67 public void infer() { | |
68 // get models to compute energies. | |
69 // for (int i=0; i<m; i++) { M[i].infer(); M[i].compute(); } | |
70 | |
71 // compute relative posterior | |
72 for (int i=0; i<m; i++) _s[i] = M[i].getEnergy(); // collect energies | |
73 double Emin=Mathx.min(_s); | |
74 for (int i=0; i<m; i++) _s[i] = _w[i]*Math.exp(Emin-_s[i]); | |
75 | |
76 // compute partition function, normalise posterior | |
77 Z.set(Mathx.sum(_s)); // compute parition function | |
78 k=Mathx.argmax(_s); // get MAP model | |
79 Mathx.mul(_s,1/Z.value); // normalise posterior | |
80 s.changed(); | |
81 } | |
82 | |
83 public void compute() | |
84 { | |
85 /* compute gradients weighted by posterior */ | |
86 Mathx.zero(_g); | |
87 for (int i=0; i<m; i++) { | |
88 double [] phi = M[i].getGradient(); | |
89 for (int j=0; j<n; j++) _g[j] += _s[i]*phi[j]; | |
90 } | |
91 } | |
92 | |
93 public double getEnergy() { return -Math.log(Z.value); } | |
94 public double [] getGradient() { return _g; } | |
95 | |
96 public Functionx functionx() { return null; } | |
97 | |
98 public void run() { infer(); } | |
99 | |
100 public Trainer getTrainer() { return new Trainer(); } | |
101 | |
102 public class Trainer implements Model.Trainer | |
103 { | |
104 Model.Trainer T[]; | |
105 VDouble rate; | |
106 VVector dw; | |
107 double batch, _dw[]; | |
108 | |
109 public Trainer() { | |
110 T=new Model.Trainer[m]; // should all be null | |
111 rate=new VDouble("rate",0.001); | |
112 dw=new VVector("dw",m); | |
113 _dw=dw.array(); | |
114 } | |
115 | |
116 public void setTrainer(int i,Model.Trainer t) { T[i]=t; } | |
117 public void dispose() { rate.dispose(); dw.dispose(); } | |
118 | |
119 public void accumulate() { accumulate(1.0); } | |
120 public void accumulate(double w) { | |
121 batch+=w; | |
122 for (int i=0;i<m; i++) { | |
123 if (T[i]!=null) T[i].accumulate(w*_s[i]); // sweet | |
124 } | |
125 | |
126 // now accumulate info about priors | |
127 Mathx.add(_dw,_s); | |
128 } | |
129 | |
130 public void oneshot() { accumulate(1.0); flush(); } | |
131 public void flush() { | |
132 for (int i=0; i<m; i++) if (T[i]!=null) T[i].flush(); | |
133 double lambda=Mathx.dot(_w,_dw)/Mathx.dot(_w,_w); | |
134 double nu=rate.value/batch; | |
135 | |
136 dw.changed(); | |
137 for (int i=0; i<m; i++) { | |
138 _w[i] *= Math.exp(nu*(_dw[i]-lambda*_w[i])); // update w | |
139 } | |
140 Mathx.zero(_dw); batch=0; | |
141 | |
142 // normalise | |
143 Mathx.mul(_w,1/Mathx.sum(_w)); | |
144 w.changed(); | |
145 } | |
146 public void reset() { | |
147 for (int i=0; i<m; i++) if (T[i]!=null) T[i].reset(); | |
148 Mathx.zero(_dw); | |
149 batch=0; | |
150 } | |
151 } | |
152 } | |
153 |