samer@0
|
1 /*
|
samer@0
|
2 * Copyright (c) 2000, Samer Abdallah, King's College London.
|
samer@0
|
3 * All rights reserved.
|
samer@0
|
4 *
|
samer@0
|
5 * This software is provided AS IS and WITHOUT ANY WARRANTY;
|
samer@0
|
6 * without even the implied warranty of MERCHANTABILITY or
|
samer@0
|
7 * FITNESS FOR A PARTICULAR PURPOSE.
|
samer@0
|
8 */
|
samer@0
|
9
|
samer@0
|
10 package samer.models;
|
samer@0
|
11 import samer.core.*;
|
samer@0
|
12 import samer.core.types.*;
|
samer@0
|
13 import samer.tools.*;
|
samer@0
|
14 import samer.maths.*;
|
samer@0
|
15 import samer.maths.*;
|
samer@0
|
16 import samer.maths.opt.*;
|
samer@0
|
17
|
samer@0
|
18 public class NoisyICA extends NamedTask implements Model
|
samer@0
|
19 {
|
samer@0
|
20 Vec x; // data (input)
|
samer@0
|
21 int n, m; // sizes (data, sources)
|
samer@0
|
22 Matrix A; // basis matrix
|
samer@0
|
23 Model Ms, Me; // source and noise models
|
samer@0
|
24 VVector s, z, e; // sources, reconstruction, error
|
samer@0
|
25 VDouble E; // energy
|
samer@0
|
26
|
samer@0
|
27 Task inference=new NullTask();
|
samer@0
|
28
|
samer@0
|
29
|
samer@0
|
30 // ----- variables used in computations -----------------
|
samer@0
|
31
|
samer@0
|
32 private double [] e_;
|
samer@0
|
33 private double [] x_;
|
samer@0
|
34 private double [] z_;
|
samer@0
|
35 private double [] s_;
|
samer@0
|
36 private VectorFunctionOfVector tA, tAt;
|
samer@0
|
37
|
samer@0
|
38
|
samer@0
|
39 public NoisyICA(Vec in,int outs) { this(new Node("noisyica"),in.size(),outs); setInput(in); }
|
samer@0
|
40 public NoisyICA(int ins,int outs) { this(new Node("noisyica"),ins,outs); }
|
samer@0
|
41 public NoisyICA(Node node, int inputs, int outputs)
|
samer@0
|
42 {
|
samer@0
|
43 super(node);
|
samer@0
|
44 Shell.push(node);
|
samer@0
|
45
|
samer@0
|
46 n = inputs;
|
samer@0
|
47 m = outputs;
|
samer@0
|
48
|
samer@0
|
49 s = new VVector("s",m);
|
samer@0
|
50 z = new VVector("z",n);
|
samer@0
|
51 e = new VVector("e",n);
|
samer@0
|
52 E = new VDouble("E");
|
samer@0
|
53 A = new Matrix("A",n,m);
|
samer@0
|
54 A.identity();
|
samer@0
|
55
|
samer@0
|
56 e_ = e.array();
|
samer@0
|
57 z_ = z.array();
|
samer@0
|
58 s_ = s.array();
|
samer@0
|
59 tA = new MatrixTimesVector(A);
|
samer@0
|
60 tAt= new MatrixTransposeTimesVector(A);
|
samer@0
|
61
|
samer@0
|
62 Shell.pop();
|
samer@0
|
63 }
|
samer@0
|
64
|
samer@0
|
65 public Model getSourceModel() { return Ms; }
|
samer@0
|
66 public Model getNoiseModel() { return Me; }
|
samer@0
|
67 public void setSourceModel(Model m) { Ms=m; Ms.setInput(s); }
|
samer@0
|
68 public void setNoiseModel(Model m) { Me=m; Me.setInput(e); }
|
samer@0
|
69 public void setInput(Vec in) { x=in; x_ = x.array(); }
|
samer@0
|
70 public Matrix basisMatrix() { return A; }
|
samer@0
|
71 public VVector output() { return s; }
|
samer@0
|
72 public VVector error() { return e; }
|
samer@0
|
73 public VVector reconstruction() { return z; }
|
samer@0
|
74
|
samer@0
|
75 public int getSize() { return n; }
|
samer@0
|
76
|
samer@0
|
77 public void setInferenceTask(Task t) { inference=t; }
|
samer@0
|
78
|
samer@0
|
79 public void infer() {
|
samer@0
|
80 try { inference.run(); }
|
samer@0
|
81 catch (Exception ex) {
|
samer@0
|
82 Shell.trace("error: "+ex);
|
samer@0
|
83 ex.printStackTrace();
|
samer@0
|
84 throw new Error("inference failed: "+ex); }
|
samer@0
|
85 tA.apply(s_,z_); Mathx.sub(e_,x_,z_);
|
samer@0
|
86 e.changed(); z.changed(); s.changed();
|
samer@0
|
87 }
|
samer@0
|
88
|
samer@0
|
89 public void compute() {
|
samer@0
|
90 E.set(Me.getEnergy() + Ms.getEnergy());
|
samer@0
|
91 // what about dE/dx?
|
samer@0
|
92 }
|
samer@0
|
93 public double getEnergy() { return E.value; }
|
samer@0
|
94 public double [] getGradient() { return null; } // this is wrong
|
samer@0
|
95
|
samer@0
|
96 /** get basis vector norms into given array */
|
samer@0
|
97 public void norms(double [] na)
|
samer@0
|
98 {
|
samer@0
|
99 double [][] M=A.getArray();
|
samer@0
|
100
|
samer@0
|
101 Mathx.zero(na);
|
samer@0
|
102 for (int i=0; i<n; i++) {
|
samer@0
|
103 double [] Mi=M[i];
|
samer@0
|
104 for (int j=0; j<m; j++) {
|
samer@0
|
105 na[j] += Mi[j]*Mi[j];
|
samer@0
|
106 }
|
samer@0
|
107 }
|
samer@0
|
108 }
|
samer@0
|
109
|
samer@0
|
110 public void run() { infer(); }
|
samer@0
|
111 public void dispose()
|
samer@0
|
112 {
|
samer@0
|
113 s.dispose();
|
samer@0
|
114 A.dispose();
|
samer@0
|
115 e.dispose();
|
samer@0
|
116 z.dispose();
|
samer@0
|
117 E.dispose();
|
samer@0
|
118 tA.dispose();
|
samer@0
|
119 tAt.dispose();
|
samer@0
|
120
|
samer@0
|
121 super.dispose();
|
samer@0
|
122 }
|
samer@0
|
123
|
samer@0
|
124 public Functionx functionx() { return null; }
|
samer@0
|
125
|
samer@0
|
126 public Functionx posterior() {
|
samer@0
|
127 // returns Functionx which evaluates E and dE/ds at current x
|
samer@0
|
128 return new Functionx() {
|
samer@0
|
129 Functionx fMs=Ms.functionx();
|
samer@0
|
130 Functionx fMe=Me.functionx();
|
samer@0
|
131 double [] e=new double[n];
|
samer@0
|
132 double [] ge=new double[n];
|
samer@0
|
133 double [] gs=new double[m];
|
samer@0
|
134
|
samer@0
|
135 public void dispose() { fMs.dispose(); fMe.dispose(); }
|
samer@0
|
136 public void evaluate( Datum P) { P.f=evaluate(P.x,P.g); }
|
samer@0
|
137 public double evaluate( double [] s, double [] g)
|
samer@0
|
138 {
|
samer@0
|
139 tA.apply(s,z_); Mathx.sub(e,x_,z_);
|
samer@0
|
140 double Ee=fMe.evaluate(e,ge); tAt.apply(ge,gs); // gs=A'*gamma(e)
|
samer@0
|
141 double Es=fMs.evaluate(s,g); Mathx.sub(g,gs); // g=gamma(s)-ge
|
samer@0
|
142 return Es+Ee;
|
samer@0
|
143 }
|
samer@0
|
144 };
|
samer@0
|
145 }
|
samer@0
|
146
|
samer@0
|
147 public Trainer learnHebbian() { return new MatrixTrainer(e_,A,s_); }
|
samer@0
|
148
|
samer@0
|
149 public Trainer learnLewickiSejnowski()
|
samer@0
|
150 {
|
samer@0
|
151 final double [] h =new double[n];
|
samer@0
|
152 final double [] f =new double[m];
|
samer@0
|
153 return new MatrixTrainer(h,A,s_) {
|
samer@0
|
154
|
samer@0
|
155 public void accumulate(double w) {
|
samer@0
|
156 // tAt.apply(Me.getGradient(),f); tA.apply(f,h);
|
samer@0
|
157 tA.apply(Ms.getGradient(),h);
|
samer@0
|
158 super.accumulate(w);
|
samer@0
|
159 }
|
samer@0
|
160 public void flush() { // flush with decay
|
samer@0
|
161 for (int j=0; j<this.m; j++)
|
samer@0
|
162 for (int i=0; i<this.n; i++)
|
samer@0
|
163 _T[i][j] -= count*_A[i][j];
|
samer@0
|
164 super.flush();
|
samer@0
|
165 }
|
samer@0
|
166 };
|
samer@0
|
167 }
|
samer@0
|
168
|
samer@0
|
169 public Trainer learnDecayWhenActive()
|
samer@0
|
170 {
|
samer@0
|
171 final double [] h =new double[n];
|
samer@0
|
172 return new MatrixTrainer(h,A,s_) {
|
samer@0
|
173 VDouble th=new VDouble("threshold",0.01);
|
samer@0
|
174 public void accumulate(double w) {
|
samer@0
|
175 // perhaps would like to get info out of optimiser here
|
samer@0
|
176 tA.apply(Ms.getGradient(),h);
|
samer@0
|
177 super.accumulate(w);
|
samer@0
|
178
|
samer@0
|
179 // decay when active part
|
samer@0
|
180 double thresh=th.value;
|
samer@0
|
181 for (int j=0; j<this.m; j++)
|
samer@0
|
182 if (isActive(s_[j],thresh))
|
samer@0
|
183 for (int i=0; i<this.n; i++)
|
samer@0
|
184 _T[i][j] -= _A[i][j];
|
samer@0
|
185 }
|
samer@0
|
186
|
samer@0
|
187 public VDouble getThreshold() { return th; }
|
samer@0
|
188 };
|
samer@0
|
189 }
|
samer@0
|
190 static boolean isActive(double s, double t) { return s>t || s<-t; }
|
samer@0
|
191 }
|