Mercurial > hg > silvet
comparison src/EM.cpp @ 97:840c0d703bbb timing
Use single-precision floats throughout EM code
author | Chris Cannam |
---|---|
date | Tue, 06 May 2014 14:45:16 +0100 |
parents | a6e136aaa202 |
children |
comparison
equal
deleted
inserted
replaced
90:f1116eb464f9 | 97:840c0d703bbb |
---|---|
26 | 26 |
27 using std::vector; | 27 using std::vector; |
28 using std::cerr; | 28 using std::cerr; |
29 using std::endl; | 29 using std::endl; |
30 | 30 |
31 static double epsilon = 1e-16; | 31 static float epsilon = 1e-8; |
32 | 32 |
33 EM::EM() : | 33 EM::EM() : |
34 m_noteCount(SILVET_TEMPLATE_NOTE_COUNT), | 34 m_noteCount(SILVET_TEMPLATE_NOTE_COUNT), |
35 m_shiftCount(SILVET_TEMPLATE_MAX_SHIFT * 2 + 1), | 35 m_shiftCount(SILVET_TEMPLATE_MAX_SHIFT * 2 + 1), |
36 m_binCount(SILVET_TEMPLATE_HEIGHT), | 36 m_binCount(SILVET_TEMPLATE_HEIGHT), |
85 } | 85 } |
86 | 86 |
87 void | 87 void |
88 EM::normaliseColumn(V &column) | 88 EM::normaliseColumn(V &column) |
89 { | 89 { |
90 double sum = 0.0; | 90 float sum = 0.0; |
91 for (int i = 0; i < (int)column.size(); ++i) { | 91 for (int i = 0; i < (int)column.size(); ++i) { |
92 sum += column[i]; | 92 sum += column[i]; |
93 } | 93 } |
94 for (int i = 0; i < (int)column.size(); ++i) { | 94 for (int i = 0; i < (int)column.size(); ++i) { |
95 column[i] /= sum; | 95 column[i] /= sum; |
113 } | 113 } |
114 } | 114 } |
115 } | 115 } |
116 | 116 |
117 void | 117 void |
118 EM::iterate(V column) | 118 EM::iterate(const vector<double> &column) |
119 { | 119 { |
120 normaliseColumn(column); | 120 V norm(column.begin(), column.end()); |
121 expectation(column); | 121 normaliseColumn(norm); |
122 maximisation(column); | 122 expectation(norm); |
123 } | 123 maximisation(norm); |
124 | 124 } |
125 const double * | 125 |
| 126 const float * |
126 EM::templateFor(int instrument, int note, int shift) | 127 EM::templateFor(int instrument, int note, int shift) |
127 { | 128 { |
128 return silvet_templates[instrument].data[note] + shift; | 129 return silvet_templates[instrument].data[note] + shift; |
129 } | 130 } |
130 | 131 |
137 m_estimate[i] = epsilon; | 138 m_estimate[i] = epsilon; |
138 } | 139 } |
139 | 140 |
140 for (int i = 0; i < m_instrumentCount; ++i) { | 141 for (int i = 0; i < m_instrumentCount; ++i) { |
141 for (int n = 0; n < m_noteCount; ++n) { | 142 for (int n = 0; n < m_noteCount; ++n) { |
142 const double pitch = m_pitches[n]; | 143 const float pitch = m_pitches[n]; |
143 const double source = m_sources[i][n]; | 144 const float source = m_sources[i][n]; |
144 for (int f = 0; f < m_shiftCount; ++f) { | 145 for (int f = 0; f < m_shiftCount; ++f) { |
145 const double *w = templateFor(i, n, f); | 146 const float *w = templateFor(i, n, f); |
146 const double shift = m_shifts[f][n]; | 147 const float shift = m_shifts[f][n]; |
147 const double factor = pitch * source * shift; | 148 const float factor = pitch * source * shift; |
148 for (int j = 0; j < m_binCount; ++j) { | 149 for (int j = 0; j < m_binCount; ++j) { |
149 m_estimate[j] += w[j] * factor; | 150 m_estimate[j] += w[j] * factor; |
150 } | 151 } |
151 } | 152 } |
152 } | 153 } |
164 Grid newShifts(m_shiftCount, V(m_noteCount, epsilon)); | 165 Grid newShifts(m_shiftCount, V(m_noteCount, epsilon)); |
165 Grid newSources(m_instrumentCount, V(m_noteCount, epsilon)); | 166 Grid newSources(m_instrumentCount, V(m_noteCount, epsilon)); |
166 | 167 |
167 for (int n = 0; n < m_noteCount; ++n) { | 168 for (int n = 0; n < m_noteCount; ++n) { |
168 | 169 |
169 const double pitch = m_pitches[n]; | 170 const float pitch = m_pitches[n]; |
170 | 171 |
171 for (int f = 0; f < m_shiftCount; ++f) { | 172 for (int f = 0; f < m_shiftCount; ++f) { |
172 | 173 |
173 const double shift = m_shifts[f][n]; | 174 const float shift = m_shifts[f][n]; |
174 | 175 |
175 for (int i = 0; i < m_instrumentCount; ++i) { | 176 for (int i = 0; i < m_instrumentCount; ++i) { |
176 | 177 |
177 const double source = m_sources[i][n]; | 178 const float source = m_sources[i][n]; |
178 const double factor = pitch * source * shift; | 179 const float factor = pitch * source * shift; |
179 const double *w = templateFor(i, n, f); | 180 const float *w = templateFor(i, n, f); |
180 | 181 |
181 if (n >= m_lowestPitch && n <= m_highestPitch) { | 182 if (n >= m_lowestPitch && n <= m_highestPitch) { |
182 | 183 |
183 for (int j = 0; j < m_binCount; ++j) { | 184 for (int j = 0; j < m_binCount; ++j) { |
184 newPitches[n] += w[j] * m_q[j] * factor; | 185 newPitches[n] += w[j] * m_q[j] * factor; |