samer@0
|
1 /*
|
samer@0
|
2 * Copyright (c) 2000, Samer Abdallah, King's College London.
|
samer@0
|
3 * All rights reserved.
|
samer@0
|
4 *
|
samer@0
|
5 * This software is provided AS IS and WITHOUT ANY WARRANTY;
|
samer@0
|
6 * without even the implied warranty of MERCHANTABILITY or
|
samer@0
|
7 * FITNESS FOR A PARTICULAR PURPOSE.
|
samer@0
|
8 */
|
samer@0
|
9
|
samer@0
|
10 package samer.models;
|
samer@0
|
11 import samer.core.*;
|
samer@0
|
12 import samer.core.types.*;
|
samer@0
|
13 import samer.tools.*;
|
samer@0
|
14 import samer.maths.*;
|
samer@0
|
15 import samer.maths.*;
|
samer@0
|
16 import samer.maths.opt.*;
|
samer@0
|
17
|
samer@0
|
/**
 * Generative model in which the input x is explained by a variance field
 * z = A*s driven by sources s: the energy penalises the ratio t = x/z via
 * t - log t (see compute() and posterior()).  Implements both the Model
 * interface and NamedTask (one run() = one inference pass).
 */
public class VarianceICA extends NamedTask implements Model
{
	Vec x;			// data (input), supplied via setInput()
	int n, m;		// sizes (data, sources)
	Matrix A;		// basis matrix (n rows, m columns)
	Model Ms;		// source model for s, supplied via setSourceModel()
	VVector s, z, e;	// sources, reconstruction z=A*s, scaled error
	VDouble E;		// energy (last value computed by compute())

	// Task used by infer() to estimate the sources; defaults to a no-op
	// until setInferenceTask() installs a real optimiser.
	Task inference=new NullTask();


	// ----- variables used in computations -----------------

	private double [] e_;	// backing array of e
	private double [] x_;	// backing array of the current input x
	private double [] z_;	// backing array of z
	private double [] s_;	// backing array of s
	private VectorFunctionOfVector tA, tAt;	// v -> A*v and v -> A'*v
|
samer@0
|
37
|
samer@0
|
38
|
samer@0
|
39 public VarianceICA(Node node, int inputs, int outputs)
|
samer@0
|
40 {
|
samer@0
|
41 super(node);
|
samer@0
|
42 Shell.push(node);
|
samer@0
|
43
|
samer@0
|
44 n = inputs;
|
samer@0
|
45 m = outputs;
|
samer@0
|
46
|
samer@0
|
47 s = new VVector("s",m);
|
samer@0
|
48 z = new VVector("z",n);
|
samer@0
|
49 e = new VVector("e",n);
|
samer@0
|
50 E = new VDouble("E");
|
samer@0
|
51 A = new Matrix("A",n,m);
|
samer@0
|
52 A.identity();
|
samer@0
|
53
|
samer@0
|
54 e_ = e.array();
|
samer@0
|
55 z_ = z.array();
|
samer@0
|
56 s_ = s.array();
|
samer@0
|
57 tA = new MatrixTimesVector(A);
|
samer@0
|
58 tAt= new MatrixTransposeTimesVector(A);
|
samer@0
|
59
|
samer@0
|
60 Shell.pop();
|
samer@0
|
61 }
|
samer@0
|
62
|
samer@0
|
	/** Returns the current source model for s (null until setSourceModel is called). */
	public Model getSourceModel() { return Ms; }
	/** Installs the source model and points its input at the source vector s. */
	public void setSourceModel(Model m) { Ms=m; Ms.setInput(s); }
	/** Sets the data vector and caches its backing array for the compute loops. */
	public void setInput(Vec in) { x=in; x_ = x.array(); }
	/** The basis matrix A. */
	public Matrix basisMatrix() { return A; }
	/** The source (output) vector s. */
	public VVector output() { return s; }
	/** The scaled error vector e (updated by infer()). */
	public VVector error() { return e; }

	/** Model size: the dimension n of the input data. */
	public int getSize() { return n; }

	/** Installs the task that infer() runs to estimate the sources. */
	public void setInferenceTask(Task t) { inference=t; }
|
samer@0
|
73
|
samer@0
|
74 public void infer() {
|
samer@0
|
75 try { inference.run(); }
|
samer@0
|
76 catch (Exception ex) { throw new Error("inference failed"); }
|
samer@0
|
77 tA.apply(s_,z_);
|
samer@0
|
78 Mathx.sub(e_,x_,z_);
|
samer@0
|
79 Mathx.div(e_,z_);
|
samer@0
|
80 Mathx.div(e_,z_);
|
samer@0
|
81 e.changed(); z.changed(); s.changed();
|
samer@0
|
82 }
|
samer@0
|
83
|
samer@0
|
84 public void compute() {
|
samer@0
|
85 double EE=0;
|
samer@0
|
86 for (int i=0; i<n; i++) {
|
samer@0
|
87 double t=x_[i]/z_[i];
|
samer@0
|
88 EE += t - Math.log(t);
|
samer@0
|
89 }
|
samer@0
|
90 E.set(EE - n + Ms.getEnergy());
|
samer@0
|
91 // what about dE/dx?
|
samer@0
|
92 }
|
samer@0
|
	/** Last energy value produced by compute(). */
	public double getEnergy() { return E.value; }
	/** dE/dx gradient — TODO: not implemented; returns null (flagged wrong by the author). */
	public double [] getGradient() { return null; } // this is wrong
|
samer@0
|
95
|
samer@0
|
	/** Task interface: one step of this task is one inference pass. */
	public void run() { infer(); }

	/** Releases all named variables and matrix-product helpers, then the node itself. */
	public void dispose()
	{
		s.dispose();
		A.dispose();
		e.dispose();
		z.dispose();
		E.dispose();
		tA.dispose();
		tAt.dispose();

		super.dispose();
	}
|
samer@0
|
109
|
samer@0
|
	/** NOTE(review): deliberately unimplemented — returns null; callers must use posterior() instead. */
	public Functionx functionx() { return null; }
|
samer@0
|
111
|
samer@0
|
112 public Functionx posterior() {
|
samer@0
|
113 // returns Functionx which evaluates E and dE/ds at current x
|
samer@0
|
114 return new Functionx() {
|
samer@0
|
115 Functionx fMs=Ms.functionx();
|
samer@0
|
116 double [] ge=new double[n];
|
samer@0
|
117
|
samer@0
|
118 public void dispose() { fMs.dispose(); }
|
samer@0
|
119 public void evaluate( Datum P) { P.f=evaluate(P.x,P.g); }
|
samer@0
|
120 public double evaluate( double [] s, double [] g)
|
samer@0
|
121 {
|
samer@0
|
122 double E=0;
|
samer@0
|
123
|
samer@0
|
124 tA.apply(s,z_); // z=As
|
samer@0
|
125 for (int i=0; i<n; i++) {
|
samer@0
|
126 double t=x_[i]/z_[i];
|
samer@0
|
127 g[i] = (t-1)/z_[i];
|
samer@0
|
128 E += t - Math.log(t);
|
samer@0
|
129 }
|
samer@0
|
130 tAt.apply(g,ge); // ge=A'*g
|
samer@0
|
131 E+=fMs.evaluate(s,g);
|
samer@0
|
132 Mathx.sub(g,ge);
|
samer@0
|
133 return E;
|
samer@0
|
134 }
|
samer@0
|
135 };
|
samer@0
|
136 }
|
samer@0
|
137
|
samer@0
|
	/**
	 * Creates a Hebbian trainer for A driven by the outer product of the
	 * error e and the sources s (both via their cached backing arrays).
	 */
	public Trainer learnHebbian() {
		Shell.push(node);
		try { return new MatrixTrainer(e_,A,s_); }
		finally { Shell.pop(); }
	}
|
samer@0
|
143
|
samer@0
|
144 public Trainer learnLewickiSejnowski()
|
samer@0
|
145 {
|
samer@0
|
146 Shell.push(node);
|
samer@0
|
147 try {
|
samer@0
|
148 final double [] h =new double[n];
|
samer@0
|
149 final double [] f =new double[m];
|
samer@0
|
150 return new MatrixTrainer(h,A,s_) {
|
samer@0
|
151
|
samer@0
|
152 public void accumulate(double w) {
|
samer@0
|
153 tA.apply(Ms.getGradient(),h);
|
samer@0
|
154 super.accumulate(w);
|
samer@0
|
155 }
|
samer@0
|
156 public void flush() { // flush with decay
|
samer@0
|
157 for (int j=0; j<this.m; j++)
|
samer@0
|
158 for (int i=0; i<this.n; i++)
|
samer@0
|
159 _T[i][j] -= count*_A[i][j];
|
samer@0
|
160 super.flush();
|
samer@0
|
161 }
|
samer@0
|
162 };
|
samer@0
|
163 } finally { Shell.pop(); }
|
samer@0
|
164 }
|
samer@0
|
165
|
samer@0
|
	/**
	 * Creates a trainer like learnLewickiSejnowski's, but instead of decaying
	 * at flush time it decays column j of the accumulator by A's column j on
	 * every accumulate() in which source j is "active" (|s_j| above the
	 * adjustable "threshold" variable).
	 */
	public Trainer learnDecayWhenActive()
	{
		Shell.push(node);
		try {
			final double [] h =new double[n];	// h = A * (source gradient)
			return new MatrixTrainer(h,A,s_) {
				VDouble th=new VDouble("threshold",0.01);	// activity threshold, user-adjustable
				public void accumulate(double w) {
					tA.apply(Ms.getGradient(),h);	// refresh h before accumulating
					super.accumulate(w);

					// decay when active part
					double thresh=th.value;
					for (int j=0; j<this.m; j++)
						if (isActive(s_[j],thresh))
							for (int i=0; i<this.n; i++)
								_T[i][j] -= _A[i][j];
				}
			};
		} finally { Shell.pop(); }
	}
|
samer@0
|
187 static boolean isActive(double s, double t) { return s>t || s<-t; }
|
samer@0
|
188 }
|