comparison src/EM.cpp @ 151:fc06b6f33021

double -> float in EM (to test)
author Chris Cannam
date Wed, 14 May 2014 19:38:36 +0100
parents d2bc51cc7f57
children 6003a9af43af
comparison
equal deleted inserted replaced
150:d2bc51cc7f57 151:fc06b6f33021
29 using std::cerr; 29 using std::cerr;
30 using std::endl; 30 using std::endl;
31 31
32 using namespace breakfastquay; 32 using namespace breakfastquay;
33 33
34 static double epsilon = 1e-16; 34 static float epsilon = 1e-10;
35 35
36 EM::EM(bool useShifts) : 36 EM::EM(bool useShifts) :
37 m_noteCount(SILVET_TEMPLATE_NOTE_COUNT), 37 m_noteCount(SILVET_TEMPLATE_NOTE_COUNT),
38 m_shiftCount(useShifts ? SILVET_TEMPLATE_MAX_SHIFT * 2 + 1 : 1), 38 m_shiftCount(useShifts ? SILVET_TEMPLATE_MAX_SHIFT * 2 + 1 : 1),
39 m_binCount(SILVET_TEMPLATE_HEIGHT), 39 m_binCount(SILVET_TEMPLATE_HEIGHT),
43 //!!! consider a modest shift sparsity e.g. 1.1 43 //!!! consider a modest shift sparsity e.g. 1.1
44 m_sourceSparsity(1.3), 44 m_sourceSparsity(1.3),
45 m_lowestPitch(silvet_templates_lowest_note), 45 m_lowestPitch(silvet_templates_lowest_note),
46 m_highestPitch(silvet_templates_highest_note) 46 m_highestPitch(silvet_templates_highest_note)
47 { 47 {
48 m_pitches = allocate<double>(m_noteCount); 48 m_pitches = allocate<float>(m_noteCount);
49 m_updatePitches = allocate<double>(m_noteCount); 49 m_updatePitches = allocate<float>(m_noteCount);
50 for (int n = 0; n < m_noteCount; ++n) { 50 for (int n = 0; n < m_noteCount; ++n) {
51 m_pitches[n] = drand48(); 51 m_pitches[n] = drand48();
52 } 52 }
53 53
54 if (useShifts) { 54 if (useShifts) {
55 m_shifts = allocate_channels<double>(m_shiftCount, m_noteCount); 55 m_shifts = allocate_channels<float>(m_shiftCount, m_noteCount);
56 m_updateShifts = allocate_channels<double>(m_shiftCount, m_noteCount); 56 m_updateShifts = allocate_channels<float>(m_shiftCount, m_noteCount);
57 for (int f = 0; f < m_shiftCount; ++f) { 57 for (int f = 0; f < m_shiftCount; ++f) {
58 for (int n = 0; n < m_noteCount; ++n) { 58 for (int n = 0; n < m_noteCount; ++n) {
59 m_shifts[f][n] = drand48(); 59 m_shifts[f][n] = drand48();
60 } 60 }
61 } 61 }
62 } else { 62 } else {
63 m_shifts = 0; 63 m_shifts = 0;
64 m_updateShifts = 0; 64 m_updateShifts = 0;
65 } 65 }
66 66
67 m_sources = allocate_channels<double>(m_sourceCount, m_noteCount); 67 m_sources = allocate_channels<float>(m_sourceCount, m_noteCount);
68 m_updateSources = allocate_channels<double>(m_sourceCount, m_noteCount); 68 m_updateSources = allocate_channels<float>(m_sourceCount, m_noteCount);
69 for (int i = 0; i < m_sourceCount; ++i) { 69 for (int i = 0; i < m_sourceCount; ++i) {
70 for (int n = 0; n < m_noteCount; ++n) { 70 for (int n = 0; n < m_noteCount; ++n) {
71 m_sources[i][n] = (inRange(i, n) ? 1.0 : 0.0); 71 m_sources[i][n] = (inRange(i, n) ? 1.0 : 0.0);
72 } 72 }
73 } 73 }
74 74
75 m_estimate = allocate<double>(m_binCount); 75 m_estimate = allocate<float>(m_binCount);
76 m_q = allocate<double>(m_binCount); 76 m_q = allocate<float>(m_binCount);
77 } 77 }
78 78
79 EM::~EM() 79 EM::~EM()
80 { 80 {
81 deallocate(m_q); 81 deallocate(m_q);
102 rangeFor(instrument, minPitch, maxPitch); 102 rangeFor(instrument, minPitch, maxPitch);
103 return (pitch >= minPitch && pitch <= maxPitch); 103 return (pitch >= minPitch && pitch <= maxPitch);
104 } 104 }
105 105
106 void 106 void
107 EM::normaliseColumn(double *column, int size) 107 EM::normaliseColumn(float *column, int size)
108 { 108 {
109 double sum = v_sum(column, size); 109 float sum = v_sum(column, size);
110 v_scale(column, 1.0 / sum, size); 110 v_scale(column, 1.0 / sum, size);
111 } 111 }
112 112
113 void 113 void
114 EM::normaliseGrid(double **grid, int size1, int size2) 114 EM::normaliseGrid(float **grid, int size1, int size2)
115 { 115 {
116 double *denominators = allocate_and_zero<double>(size2); 116 float *denominators = allocate_and_zero<float>(size2);
117 117
118 for (int i = 0; i < size1; ++i) { 118 for (int i = 0; i < size1; ++i) {
119 for (int j = 0; j < size2; ++j) { 119 for (int j = 0; j < size2; ++j) {
120 denominators[j] += grid[i][j]; 120 denominators[j] += grid[i][j];
121 } 121 }
129 } 129 }
130 130
131 void 131 void
132 EM::iterate(const double *column) 132 EM::iterate(const double *column)
133 { 133 {
134 double *norm = allocate<double>(m_binCount); 134 float *norm = allocate<float>(m_binCount);
135 v_copy(norm, column, m_binCount); 135 v_convert(norm, column, m_binCount);
136 normaliseColumn(norm, m_binCount); 136 normaliseColumn(norm, m_binCount);
137 expectation(norm); 137 expectation(norm);
138 maximisation(norm); 138 maximisation(norm);
139 deallocate(norm); 139 deallocate(norm);
140 } 140 }
141 141
142 const double * 142 const float *
143 EM::templateFor(int instrument, int note, int shift) 143 EM::templateFor(int instrument, int note, int shift)
144 { 144 {
145 if (m_shifts) { 145 if (m_shifts) {
146 return silvet_templates[instrument].data[note] + shift; 146 return silvet_templates[instrument].data[note] + shift;
147 } else { 147 } else {
149 SILVET_TEMPLATE_MAX_SHIFT; 149 SILVET_TEMPLATE_MAX_SHIFT;
150 } 150 }
151 } 151 }
152 152
153 void 153 void
154 EM::expectation(const double *column) 154 EM::expectation(const float *column)
155 { 155 {
156 // cerr << "."; 156 // cerr << ".";
157 157
158 v_set(m_estimate, epsilon, m_binCount); 158 v_set(m_estimate, epsilon, m_binCount);
159 159
160 for (int f = 0; f < m_shiftCount; ++f) { 160 for (int f = 0; f < m_shiftCount; ++f) {
161 161
162 const double *shiftIn = m_shifts ? m_shifts[f] : 0; 162 const float *shiftIn = m_shifts ? m_shifts[f] : 0;
163 163
164 for (int i = 0; i < m_sourceCount; ++i) { 164 for (int i = 0; i < m_sourceCount; ++i) {
165 165
166 const double *sourceIn = m_sources[i]; 166 const float *sourceIn = m_sources[i];
167 167
168 int lowest, highest; 168 int lowest, highest;
169 rangeFor(i, lowest, highest); 169 rangeFor(i, lowest, highest);
170 170
171 for (int n = lowest; n <= highest; ++n) { 171 for (int n = lowest; n <= highest; ++n) {
172 172
173 const double source = sourceIn[n]; 173 const float source = sourceIn[n];
174 const double shift = shiftIn ? shiftIn[n] : 1.0; 174 const float shift = shiftIn ? shiftIn[n] : 1.0;
175 const double pitch = m_pitches[n]; 175 const float pitch = m_pitches[n];
176 176
177 const double factor = pitch * source * shift; 177 const float factor = pitch * source * shift;
178 const double *w = templateFor(i, n, f); 178 const float *w = templateFor(i, n, f);
179 179
180 v_add_with_gain(m_estimate, w, factor, m_binCount); 180 v_add_with_gain(m_estimate, w, factor, m_binCount);
181 } 181 }
182 } 182 }
183 } 183 }
189 m_q[i] = column[i] / m_estimate[i]; 189 m_q[i] = column[i] / m_estimate[i];
190 } 190 }
191 } 191 }
192 192
193 void 193 void
194 EM::maximisation(const double *column) 194 EM::maximisation(const float *column)
195 { 195 {
196 v_set(m_updatePitches, epsilon, m_noteCount); 196 v_set(m_updatePitches, epsilon, m_noteCount);
197 197
198 for (int i = 0; i < m_sourceCount; ++i) { 198 for (int i = 0; i < m_sourceCount; ++i) {
199 v_set(m_updateSources[i], epsilon, m_noteCount); 199 v_set(m_updateSources[i], epsilon, m_noteCount);
203 for (int i = 0; i < m_shiftCount; ++i) { 203 for (int i = 0; i < m_shiftCount; ++i) {
204 v_set(m_updateShifts[i], epsilon, m_noteCount); 204 v_set(m_updateShifts[i], epsilon, m_noteCount);
205 } 205 }
206 } 206 }
207 207
208 double *contributions = allocate<double>(m_binCount); 208 float *contributions = allocate<float>(m_binCount);
209 209
210 for (int f = 0; f < m_shiftCount; ++f) { 210 for (int f = 0; f < m_shiftCount; ++f) {
211 211
212 const double *shiftIn = m_shifts ? m_shifts[f] : 0; 212 const float *shiftIn = m_shifts ? m_shifts[f] : 0;
213 double *shiftOut = m_shifts ? m_updateShifts[f] : 0; 213 float *shiftOut = m_shifts ? m_updateShifts[f] : 0;
214 214
215 for (int i = 0; i < m_sourceCount; ++i) { 215 for (int i = 0; i < m_sourceCount; ++i) {
216 216
217 const double *sourceIn = m_sources[i]; 217 const float *sourceIn = m_sources[i];
218 double *sourceOut = m_updateSources[i]; 218 float *sourceOut = m_updateSources[i];
219 219
220 int lowest, highest; 220 int lowest, highest;
221 rangeFor(i, lowest, highest); 221 rangeFor(i, lowest, highest);
222 222
223 for (int n = lowest; n <= highest; ++n) { 223 for (int n = lowest; n <= highest; ++n) {
224 224
225 const double shift = shiftIn ? shiftIn[n] : 1.0; 225 const float shift = shiftIn ? shiftIn[n] : 1.0;
226 const double source = sourceIn[n]; 226 const float source = sourceIn[n];
227 const double pitch = m_pitches[n]; 227 const float pitch = m_pitches[n];
228 228
229 const double factor = pitch * source * shift; 229 const float factor = pitch * source * shift;
230 const double *w = templateFor(i, n, f); 230 const float *w = templateFor(i, n, f);
231 231
232 v_copy(contributions, w, m_binCount); 232 v_copy(contributions, w, m_binCount);
233 v_multiply(contributions, m_q, m_binCount); 233 v_multiply(contributions, m_q, m_binCount);
234 234
235 double total = factor * v_sum(contributions, m_binCount); 235 float total = factor * v_sum(contributions, m_binCount);
236 236
237 m_updatePitches[n] += total; 237 m_updatePitches[n] += total;
238 sourceOut[n] += total; 238 sourceOut[n] += total;
239 239
240 if (shiftOut) { 240 if (shiftOut) {