/*
 * Copyright (c) 2002, Samer Abdallah, King's College London.
 * All rights reserved.
 *
 * This software is provided AS IS and WITHOUT ANY WARRANTY;
 * without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.
 */

package samer.models;

import samer.core.*;
import samer.core.types.*;
import samer.maths.*;
import samer.maths.opt.*;
import samer.tools.*;

/**
	Differential scaler: scales and offsets each element of a vector
	independently, aiming to match a given prior model. This is
	like ICA using a diagonal weight matrix, with a 'zero-mean'
	normalisation built in, though it does not actually use the mean
	to do this, but a statistic based on the prior model.
*/

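// In effect, the transform and energy implemented below are
//     s_i = w_i * (x_i - mu_i),    E(x) = E_M(s) - sum_i log w_i,
// where E_M is the energy of the prior model M. The following is a minimal
// usage sketch only: the prior class name is hypothetical, and the exact
// per-observation call sequence depends on the Model implementations in use.
//
//     VVector in = new VVector("in", 16);             // input vector
//     Model prior = new GeneralisedExponential(16);   // hypothetical prior over s
//     DiffScaler scaler = new DiffScaler(in, prior);
//     Model.Trainer trainer = scaler.getTrainer();
//
//     // for each observation written into in:
//     scaler.infer();          // compute s = w .* (x - mu)
//     prior.infer();           // let the prior update its energy/gradient (interface-dependent)
//     trainer.accumulate();    // gather scale/offset gradient statistics
//     trainer.flush();         // apply one update and reset the accumulators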
public class DiffScaler extends AnonymousTask implements Model
{
	private Model M;	// models P(s)
	private int n;
	private Vec x;		// input vector
	private VVector s;	// output vector
	private VVector w;	// vector of multipliers
	private VVector mu;	// vector of offsets
	private VDouble logA;	// carries -sum(log w), i.e. log|A| with A = diag(w)^{-1}

	double [] _x, _s, _g, _w, _m, phi;	// backing arrays for x, s, gradient, w and mu (phi is currently unused)

	public DiffScaler( Vec input, Model M) { this(input); setOutputModel(M); M.setInput(s); }
	public DiffScaler( Vec input) { this(input.size()); setInput(input); }
	public DiffScaler( int N)
	{
		n = N;

		x = null;
		mu = new VVector("mu",n); mu.addSaver();
		w = new VVector("w",n); w.addSaver();
		s = new VVector("output",n);
		logA = new VDouble("log|A|");

		_s = s.array();
		_w = w.array();
		_m = mu.array();
		_g = new double[n];
		phi = null;
		reset();
	}

	public int getSize() { return n; }
	public VVector output() { return s; }
	public VVector weights() { return w; }
	public VVector offsets() { return mu; }
	public Model getOutputModel() { return M; }
	public void setOutputModel(Model m) { M=m; }
	public void setInput(Vec in) { x=in; _x=x.array(); }
	public void reset() { reset(1.0); }
	public void reset(double k) {
		Mathx.setAll(_w,k); w.changed();
		Mathx.setAll(_m,0); mu.changed();
		logA.set(-sumlog(_w));
	}

	public String toString() { return "DiffScaler("+x+")"; }
	private static double sumlog(double [] x) {
		double S=0;
		for (int i=0; i<x.length; i++) S += Math.log(x[i]);
		return S;
	}

	public void dispose()
	{
		logA.dispose();
		mu.dispose();
		w.dispose();
		s.dispose();
		super.dispose();
	}

	public void infer() {
		// s = w .* (x - mu)
		for (int i=0; i<n; i++) _s[i] = _w[i]*(_x[i]-_m[i]);
		s.changed();
	}

	public void compute() {
		// chain rule: dE/dx = w .* dE/ds
		Mathx.mul(_g,M.getGradient(),_w);
	}

	public double getEnergy() { return M.getEnergy() + logA.value; }
	public double [] getGradient() { return _g; }

	public Functionx functionx() {
		return new Functionx() {
			Functionx fM=M.functionx();
			double [] s=new double[n];

			public void dispose() { fM.dispose(); }
			public void evaluate(Datum P) { P.f=evaluate(P.x,P.g); }
			public double evaluate(double [] x, double [] g) {
				for (int i=0; i<n; i++) s[i] = _w[i]*(x[i]-_m[i]);
				double E=fM.evaluate(s,g);
				Mathx.mul(g,_w);
				return E+logA.value;
			}
		};
	}

	public void starting() { logA.set(-sumlog(_w)); }
	public void stopping() {}
	public void run() { infer(); }

	public Model.Trainer getTrainer() { return new Trainer(); }
	public Model.Trainer getOffsetTrainer() { return new OffsetTrainer(); }
	public Model.Trainer getScaleTrainer() { return new ScaleTrainer(); }
	public Model.Trainer getTensionedTrainer() { return new TensionedTrainer(); }


	public class Trainer extends AnonymousTask implements Model.Trainer
	{
		VDouble rate1=new VDouble("scaleRate",0.0001);
		VDouble rate2=new VDouble("offsetRate",0.000001);
		double[] G,H;		// accumulated scale and offset gradient statistics
		double [] _s;
		double count;		// total accumulated weight

		public Trainer() {
			// n=DiffScaler.this.n;
			_s = s.array();
			G = new double[n];
			H = new double[n];
		}

		public String toString() { return "Trainer:"+DiffScaler.this; }

		public VDouble getScaleRate() { return rate1; }
		public VDouble getOffsetRate() { return rate2; }

		public void reset() { count=0; Mathx.zero(G); Mathx.zero(H); }
		public void accumulate() { accumulate(1); }
		public void accumulate(double w) {
			double [] phi=M.getGradient();
			for (int i=0; i<n; i++) {
				G[i] += w*(phi[i]*_s[i] - 1);	// scale statistic: dE/d(log w_i) = phi_i*s_i - 1
				H[i] += w*phi[i];		// offset statistic: phi(s)
			}
			count+=w;
		}

		public void flush() {
			if (count==0) return;	// nothing to do

			double eta1 = rate1.value/count;
			double eta2 = rate2.value/count;

			for (int i=0; i<n; i++) {
				_m[i] += eta2*H[i]/_w[i];
				_w[i] *= Math.exp(-eta1*G[i]);	// multiplicative update keeps w positive
			}
			logA.value+=eta1*Mathx.sum(G);		// keep logA = -sum(log w) consistent with the new w
			logA.changed();
			mu.changed();
			w.changed();
			reset();
		}

		public void oneshot() { reset(); accumulate(); flush(); }
		public void dispose() { rate1.dispose(); rate2.dispose(); }
		public void starting() { reset(); }
		public void run() { accumulate(); flush(); }
	}

	public class ScaleTrainer extends AnonymousTask implements Model.Trainer
	{
		VDouble scale=new VDouble("scale",0.001);
		VDouble stretch=new VDouble("stretch",0.001/n);
		double thresh=Shell.getDouble("anomaly",20*n);	// skip observations with energy above this
		double[] G, H, _s;
		double count;		// total accumulated weight

		public ScaleTrainer() {
			_s = s.array();
			G = new double[n];
			// H = new double[n];
		}

		public String toString() { return "ScaleTrainer:"+DiffScaler.this; }

		public VDouble getScaleRate() { return scale; }
		public VDouble getStretchRate() { return stretch; }

		public void reset() { count=0; Mathx.zero(G); }
		public void accumulate() { accumulate(1); }
		public void accumulate(double w) {
			if (M.getEnergy()>thresh) return;	// ignore anomalous (high-energy) observations
			double [] phi=M.getGradient();
			for (int i=0; i<n; i++) {
				G[i] += w*(phi[i]*_s[i] - 1);
				// H[i] += w*phi[i];
			}
			count+=w;
		}

		public void flush() {
			if (count==0) return;	// nothing to do

			{	// filter elements of G: shrink each towards the mean, so the overall
				// scale adapts at rate 'scale' and the per-component deviations at rate 'stretch'
				double beta = stretch.value/scale.value;
				double meanG = Mathx.sum(G)/n;
				for (int i=0; i<n; i++) {
					G[i]=meanG+beta*(G[i]-meanG);
				}
			}

			double alpha = scale.value/count;
			for (int i=0; i<n; i++) {
				double tmp=Math.exp(-alpha*G[i]);
				if (Double.isNaN(tmp)) throw new Error("ScaleTrainer: NaN in weight update at "+i);
				_w[i] *= tmp;
			}
			logA.value+=alpha*Mathx.sum(G);
			logA.changed();
			w.changed();
			reset();
		}

		public void oneshot() { reset(); accumulate(); flush(); }
		public void dispose() { scale.dispose(); stretch.dispose(); }
		public void starting() { reset(); }
		public void run() { accumulate(); flush(); }
	}

	public class OffsetTrainer extends AnonymousTask implements Model.Trainer
	{
		VDouble rate2=new VDouble("offsetRate",0.000001);
		double[] H;
		double count;		// total accumulated weight

		public OffsetTrainer() {
			// n=DiffScaler.this.n;
			H = new double[n];
		}

		public String toString() { return "OffsetTrainer:"+DiffScaler.this; }

		public VDouble getOffsetRate() { return rate2; }

		public void reset() { count=0; Mathx.zero(H); }
		public void accumulate() { accumulate(1); }
		public void accumulate(double w) {
			double [] phi=M.getGradient();
			for (int i=0; i<n; i++) H[i] += w*phi[i];
			count+=w;
		}

		public void flush() {
			if (count==0) return;	// nothing to do

			double eta2 = rate2.value/count;

			for (int i=0; i<n; i++) {
				_m[i] += eta2*H[i]/_w[i];
			}
			mu.changed();
			reset();
		}

		public void oneshot() { reset(); accumulate(1); flush(); }
		public void dispose() { rate2.dispose(); }
		public void starting() { reset(); }
		public void run() { accumulate(1); flush(); }
	}

	public class TensionedTrainer extends Trainer
	{
		VDouble tension=new VDouble("tension",0.01);
		double lw[]=new double[n];

		public TensionedTrainer() {}

		public String toString() { return "TensionedTrainer:"+DiffScaler.this; }

		public VDouble getTension() { return tension; }

		public void flush() {
			double T=tension.value;
			int i;

			// add a smoothness 'tension' term to G: penalise the discrete
			// second difference of log w across neighbouring components
			for (i=0; i<n; i++) { lw[i]=Math.log(_w[i]); }
			G[0] -= T*(lw[1]-lw[0]); i=n-1;
			G[i] -= T*(lw[i-1]-lw[i]); i--;
			for (; i>0; i--) {
				G[i] -= T*(lw[i-1] -2*lw[i]+lw[i+1]);
			}

			// do the same for H (the offsets)?
			// may have to modify T.
			H[0] -= T*(_m[1]-_m[0]); i=n-1;
			H[i] -= T*(_m[i-1]-_m[i]); i--;
			for (; i>0; i--) {
				H[i] -= T*(_m[i-1] -2*_m[i]+_m[i+1]);
			}

			super.flush();
		}

		public void dispose() { tension.dispose(); super.dispose(); }
	}
}