/*
 * Copyright (c) 2002, Samer Abdallah, King's College London.
 * All rights reserved.
 *
 * This software is provided AS IS and WITHOUT ANY WARRANTY;
 * without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.
 */

package samer.models;
import samer.core.*;
import samer.core.Agent.*;
import samer.core.types.*;
import samer.tools.*;
import samer.maths.*;

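/**
    An ICA model with an extra per-source scaling vector k. infer() runs the
    inherited inference step (s = Wx) and then multiplies each component of s
    by the corresponding entry of k; fold() absorbs the accumulated scale
    factors into the rows of W and resets k to ones. The nested trainers update
    k multiplicatively from the trace and diagonal of the accumulated statistics G.
*/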
public class ICAWithScaler extends ICA
{
    VVector  k;          // state, scaling vector
    double[] __k, tmp;

    public ICAWithScaler(Vec input) { this(input.size()); setInput(input); }
    public ICAWithScaler(int N)
    {
        super(N);
        Shell.push(node);
        k = new VVector("k",n);
        Shell.pop();
        __k=k.array();
        tmp=new double[n];

        for (int i=0; i<n; i++) __k[i]=1;
        k.changed();
    }

    public void dispose() { k.dispose(); super.dispose(); }
    public void infer() {   // this overrides ICA.infer
        infer.run();        // compute s=Wx
        Mathx.mul(s.array(),__k);
        s.changed();
    }

    public void compute() {
        Mathx.mul(tmp,sourceModel.getGradient(),__k);
        grad.apply(tmp,_g);
    }

    public samer.maths.opt.Functionx functionx() { return null; }

    public void fold() {
        for (int i=0; i<n; i++) {
            Mathx.mul(W.getArray()[i],__k[i]);
            __k[i]=1;
        }
        k.changed();
        W.changed();
    }

    public Trainer getDiffTrainer() { return new DifferentialTrainer(); }
    public Trainer getScaleTrainer() { return new ScalerTrainer(); }


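    /** Variant of ON3Trainer: diffFlush() applies the trace and diagonal of the
        accumulated statistics G multiplicatively to the scaling vector k and
        zeroes the diagonal of G, so scale updates can be flushed more often than
        the full W update; a full flush() folds k back into W and then defers to
        ON3Trainer.flush(). */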
    public class DifferentialTrainer extends ON3Trainer {
        VDouble scaleRate,stretchRate;
        double  lastflush;

        public DifferentialTrainer() {
            scaleRate=new VDouble("alpha",0.1);
            stretchRate=new VDouble("beta",0.05);
        }

        public void reset() { super.reset(); lastflush=0; }
        public void flush() { diffFlush(); fold(); super.flush(); lastflush=0; }
        public void diffFlush()     // flush multipliers to k instead of W
        {
            double batchlet=batch-lastflush;
            if (batchlet==0) return;

            // do differential learning on trace & diagonal of G
            double alpha=scaleRate.value/batchlet;
            double beta=stretchRate.value/batchlet;

            // compute factors and scale each component of k (instead of each row of W)
            double mu=G.trace()/n, dl;
            for (int i=0; i<_n; i++) {
                dl=alpha*mu+beta*(_G[i][i]-mu);
                double tmp=Math.exp(-dl);
                // if (Double.isNaN(tmp)) throw new Error("alt: NaN"+i);
                __k[i]*=tmp;    // instead of Mathx.mul(_W[i],tmp);
                _G[i][i]=0;
            }
            k.changed();
            lastflush=batch;
        }
    }

    /** This trainer updates ONLY the scaler part (the vector k), not the ICA
        matrix W, so it is a lot faster than using the DifferentialTrainer with
        a zero learning rate. See the illustrative usage sketch at the end of
        this file. */

    public class ScalerTrainer extends AnonymousTask implements Model.Trainer
    {
        VVector  G;
        double[] _G;
        double[] _g,_s;
        VDouble  scaleRate,stretchRate;
        int      _n;
        double   batch,thresh;

        public ScalerTrainer()
        {
            _n=n;
            G=new VVector("G",n);
            thresh=Shell.getDouble("anomaly",20*n);
            scaleRate=new VDouble("alpha",0.02);
            stretchRate=new VDouble("beta",0.002);
            batch=0;

            _s=s.array();
            _G=G.array();
        }

        public void starting() { reset(); }
        public void run() { accumulate(); }

        public void dispose() { G.dispose(); scaleRate.dispose(); stretchRate.dispose(); super.dispose(); }
        public void oneshot() { accumulate(); flush(); }
        public void reset() { Mathx.zero(_G); batch=0; }
        public void accumulate() { accumulate(1); }
        public void accumulate(double w) {
            // HACK! skip anomalous frames whose energy exceeds the "anomaly" threshold
            if (sourceModel.getEnergy()>thresh) return;
            batch+=w;

            double[] phi=sourceModel.getGradient();
            for (int i=0; i<_n; i++) _G[i] += w*(phi[i]*_s[i] - 1);
        }

        public void flush()
        {
            if (batch==0) return;

            G.changed();

            // do differential learning on trace & diagonal of G
            double alpha=scaleRate.value/batch;
            double beta=stretchRate.value/batch;

            // compute factors and scale each component of k (instead of each row of W)
            double mu=Mathx.sum(_G)/n, dl;
            for (int i=0; i<_n; i++) {
                dl=alpha*mu+beta*(_G[i]-mu);
                double tmp=Math.exp(-dl);
                if (Double.isNaN(tmp)) throw new Error("alt: NaN"+i);
                __k[i]*=tmp;    // instead of Mathx.mul(_W[i],tmp);
            }
            k.changed();

            reset();    // ready for next batch
        }
    }
}
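
/*
    Illustrative usage sketch (not part of the original source): one way to drive
    the ScalerTrainer returned by getScaleTrainer(). The driver loop, the input
    vector "input", and the constants numFrames/blockSize are assumptions, as is
    the assumption that the Trainer interface exposes reset(), accumulate() and
    flush() as implemented above; the sketch also assumes the model's source
    model and input have been configured as for an ordinary ICA.

        ICAWithScaler model   = new ICAWithScaler(input);   // input: some samer.core.Vec
        Model.Trainer trainer = model.getScaleTrainer();

        trainer.reset();
        for (int t=0; t<numFrames; t++) {
            // ... load the next observation into "input" ...
            model.infer();                      // s = k .* (W x)
            trainer.accumulate();               // gather per-source statistics
            if (t%blockSize == blockSize-1)
                trainer.flush();                // multiplicative update of k
        }
*/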