Mercurial > hg > silvet
comparison src/EM.cpp @ 36:74b77a4d6552
Fill out the EM logic
author | Chris Cannam |
---|---|
date | Fri, 04 Apr 2014 17:48:06 +0100 |
parents | 461d94ed3816 |
children | 947996aac974 |
comparison
equal
deleted
inserted
replaced
35:461d94ed3816 | 36:74b77a4d6552 |
---|---|
15 | 15 |
16 #include "EM.h" | 16 #include "EM.h" |
17 | 17 |
18 #include "data/include/templates.h" | 18 #include "data/include/templates.h" |
19 | 19 |
20 #include <cstdlib> | |
21 | |
22 #include <iostream> | |
23 | |
24 #include <vector> | |
25 | |
26 using std::vector; | |
27 using std::cerr; | |
28 using std::endl; | |
29 | |
20 static double epsilon = 1e-16; | 30 static double epsilon = 1e-16; |
21 | 31 |
22 EM::EM() : | 32 EM::EM() : |
23 m_notes(SILVET_TEMPLATE_NOTE_COUNT), | 33 m_notes(SILVET_TEMPLATE_NOTE_COUNT), |
24 m_bins(SILVET_TEMPLATE_HEIGHT), | 34 m_bins(SILVET_TEMPLATE_HEIGHT), |
25 m_instruments(SILVET_TEMPLATE_COUNT) | 35 m_instruments(SILVET_TEMPLATE_COUNT) |
26 { | 36 { |
37 cerr << "init!" << endl; | |
38 | |
27 m_lowest = 0; | 39 m_lowest = 0; |
28 m_highest = m_notes - 1; | 40 m_highest = m_notes - 1; |
29 | 41 |
30 for (int i = 0; i < m_instruments; ++i) { | 42 for (int i = 0; i < m_instruments; ++i) { |
31 if (i == 0 || silvet_templates[i].lowest < m_lowest) { | 43 if (i == 0 || silvet_templates[i].lowest < m_lowest) { |
36 } | 48 } |
37 } | 49 } |
38 | 50 |
39 m_pitches = V(m_notes); | 51 m_pitches = V(m_notes); |
40 | 52 |
41 for (int n = 0; n < m_notes; ++i) { | 53 for (int n = 0; n < m_notes; ++n) { |
42 m_pitches[n] = drand48(); | 54 m_pitches[n] = drand48(); |
43 } | 55 } |
44 | 56 |
45 m_sources = Grid(m_instruments); | 57 m_sources = Grid(m_instruments); |
46 | 58 |
49 for (int n = 0; n < m_notes; ++n) { | 61 for (int n = 0; n < m_notes; ++n) { |
50 m_sources[i][n] = (inRange(i, n) ? 1.0 : 0.0); | 62 m_sources[i][n] = (inRange(i, n) ? 1.0 : 0.0); |
51 } | 63 } |
52 } | 64 } |
53 | 65 |
66 m_estimate = V(m_bins); | |
54 m_q = V(m_bins); | 67 m_q = V(m_bins); |
55 | 68 |
56 for (int w = 0; w < m_bins; ++w) { | 69 cerr << "(init done)" << endl; |
57 m_q[w] = epsilon; | |
58 } | |
59 } | 70 } |
60 | 71 |
61 EM::~EM() | 72 EM::~EM() |
62 { | 73 { |
63 } | 74 } |
67 { | 78 { |
68 return (note >= silvet_templates[instrument].lowest && | 79 return (note >= silvet_templates[instrument].lowest && |
69 note <= silvet_templates[instrument].highest); | 80 note <= silvet_templates[instrument].highest); |
70 } | 81 } |
71 | 82 |
83 void | |
84 EM::normalise(V &column) | |
85 { | |
86 double sum = 0.0; | |
87 for (int i = 0; i < (int)column.size(); ++i) { | |
88 sum += column[i]; | |
89 } | |
90 for (int i = 0; i < (int)column.size(); ++i) { | |
91 column[i] /= sum; | |
92 } | |
93 } | |
94 | |
95 void | |
96 EM::iterate(V column) | |
97 { | |
98 normalise(column); | |
99 expectation(column); | |
100 maximisation(column); | |
101 } | |
102 | |
103 void | |
104 EM::expectation(const V &column) | |
105 { | |
106 cerr << "."; | |
107 | |
108 for (int i = 0; i < m_bins; ++i) { | |
109 m_estimate[i] = epsilon; | |
110 } | |
111 | |
112 for (int i = 0; i < m_instruments; ++i) { | |
113 for (int n = 0; n < m_notes; ++n) { | |
114 float *w = silvet_templates[i].data[n]; | |
115 double pitch = m_pitches[n]; | |
116 double source = m_sources[i][n]; | |
117 for (int j = 0; j < m_bins; ++j) { | |
118 m_estimate[j] += w[j] * pitch * source; | |
119 } | |
120 } | |
121 } | |
122 | |
123 for (int i = 0; i < m_bins; ++i) { | |
124 m_q[i] = column[i] / m_estimate[i]; | |
125 } | |
126 } | |
127 | |
128 void | |
129 EM::maximisation(const V &column) | |
130 { | |
131 V newPitches = m_pitches; | |
132 | |
133 for (int n = 0; n < m_notes; ++n) { | |
134 newPitches[n] = epsilon; | |
135 if (n >= m_lowest && n <= m_highest) { | |
136 for (int i = 0; i < m_instruments; ++i) { | |
137 float *w = silvet_templates[i].data[n]; | |
138 double pitch = m_pitches[n]; | |
139 double source = m_sources[i][n]; | |
140 for (int j = 0; j < m_bins; ++j) { | |
141 newPitches[n] += w[j] * m_q[j] * pitch * source; | |
142 } | |
143 } | |
144 } | |
145 } | |
146 normalise(newPitches); | |
147 | |
148 Grid newSources = m_sources; | |
149 | |
150 for (int i = 0; i < m_instruments; ++i) { | |
151 for (int n = 0; n < m_notes; ++n) { | |
152 newSources[i][n] = epsilon; | |
153 if (inRange(i, n)) { | |
154 float *w = silvet_templates[i].data[n]; | |
155 for (int j = 0; j < m_bins; ++j) { | |
156 newSources[i][n] += | |
157 w[j] * m_q[j] * m_pitches[n] * m_sources[i][n]; | |
158 } | |
159 } | |
160 } | |
161 normalise(newSources[i]); | |
162 } | |
163 | |
164 m_pitches = newPitches; | |
165 m_sources = newSources; | |
166 } | |
167 | |
168 void | |
169 EM::report() | |
170 { | |
171 vector<int> sounding; | |
172 for (int n = 0; n < m_notes; ++n) { | |
173 if (m_pitches[n] > 0.05) { | |
174 sounding.push_back(n); | |
175 } | |
176 } | |
177 cerr << " sounding: "; | |
178 for (int i = 0; i < (int)sounding.size(); ++i) { | |
179 cerr << sounding[i] << " "; | |
180 } | |
181 cerr << endl; | |
182 } | |
183 |