comparison src/EM.cpp @ 97:840c0d703bbb timing

Use single-precision floats throughout EM code
author Chris Cannam
date Tue, 06 May 2014 14:45:16 +0100
parents a6e136aaa202
children
comparison
equal deleted inserted replaced
90:f1116eb464f9 97:840c0d703bbb
26 26
27 using std::vector; 27 using std::vector;
28 using std::cerr; 28 using std::cerr;
29 using std::endl; 29 using std::endl;
30 30
31 static double epsilon = 1e-16; 31 static float epsilon = 1e-8;
32 32
33 EM::EM() : 33 EM::EM() :
34 m_noteCount(SILVET_TEMPLATE_NOTE_COUNT), 34 m_noteCount(SILVET_TEMPLATE_NOTE_COUNT),
35 m_shiftCount(SILVET_TEMPLATE_MAX_SHIFT * 2 + 1), 35 m_shiftCount(SILVET_TEMPLATE_MAX_SHIFT * 2 + 1),
36 m_binCount(SILVET_TEMPLATE_HEIGHT), 36 m_binCount(SILVET_TEMPLATE_HEIGHT),
85 } 85 }
86 86
87 void 87 void
88 EM::normaliseColumn(V &column) 88 EM::normaliseColumn(V &column)
89 { 89 {
90 double sum = 0.0; 90 float sum = 0.0;
91 for (int i = 0; i < (int)column.size(); ++i) { 91 for (int i = 0; i < (int)column.size(); ++i) {
92 sum += column[i]; 92 sum += column[i];
93 } 93 }
94 for (int i = 0; i < (int)column.size(); ++i) { 94 for (int i = 0; i < (int)column.size(); ++i) {
95 column[i] /= sum; 95 column[i] /= sum;
113 } 113 }
114 } 114 }
115 } 115 }
116 116
117 void 117 void
118 EM::iterate(V column) 118 EM::iterate(const vector<double> &column)
119 { 119 {
120 normaliseColumn(column); 120 V norm(column.begin(), column.end());
121 expectation(column); 121 normaliseColumn(norm);
122 maximisation(column); 122 expectation(norm);
123 } 123 maximisation(norm);
124 124 }
125 const double * 125
126 const float *
126 EM::templateFor(int instrument, int note, int shift) 127 EM::templateFor(int instrument, int note, int shift)
127 { 128 {
128 return silvet_templates[instrument].data[note] + shift; 129 return silvet_templates[instrument].data[note] + shift;
129 } 130 }
130 131
137 m_estimate[i] = epsilon; 138 m_estimate[i] = epsilon;
138 } 139 }
139 140
140 for (int i = 0; i < m_instrumentCount; ++i) { 141 for (int i = 0; i < m_instrumentCount; ++i) {
141 for (int n = 0; n < m_noteCount; ++n) { 142 for (int n = 0; n < m_noteCount; ++n) {
142 const double pitch = m_pitches[n]; 143 const float pitch = m_pitches[n];
143 const double source = m_sources[i][n]; 144 const float source = m_sources[i][n];
144 for (int f = 0; f < m_shiftCount; ++f) { 145 for (int f = 0; f < m_shiftCount; ++f) {
145 const double *w = templateFor(i, n, f); 146 const float *w = templateFor(i, n, f);
146 const double shift = m_shifts[f][n]; 147 const float shift = m_shifts[f][n];
147 const double factor = pitch * source * shift; 148 const float factor = pitch * source * shift;
148 for (int j = 0; j < m_binCount; ++j) { 149 for (int j = 0; j < m_binCount; ++j) {
149 m_estimate[j] += w[j] * factor; 150 m_estimate[j] += w[j] * factor;
150 } 151 }
151 } 152 }
152 } 153 }
164 Grid newShifts(m_shiftCount, V(m_noteCount, epsilon)); 165 Grid newShifts(m_shiftCount, V(m_noteCount, epsilon));
165 Grid newSources(m_instrumentCount, V(m_noteCount, epsilon)); 166 Grid newSources(m_instrumentCount, V(m_noteCount, epsilon));
166 167
167 for (int n = 0; n < m_noteCount; ++n) { 168 for (int n = 0; n < m_noteCount; ++n) {
168 169
169 const double pitch = m_pitches[n]; 170 const float pitch = m_pitches[n];
170 171
171 for (int f = 0; f < m_shiftCount; ++f) { 172 for (int f = 0; f < m_shiftCount; ++f) {
172 173
173 const double shift = m_shifts[f][n]; 174 const float shift = m_shifts[f][n];
174 175
175 for (int i = 0; i < m_instrumentCount; ++i) { 176 for (int i = 0; i < m_instrumentCount; ++i) {
176 177
177 const double source = m_sources[i][n]; 178 const float source = m_sources[i][n];
178 const double factor = pitch * source * shift; 179 const float factor = pitch * source * shift;
179 const double *w = templateFor(i, n, f); 180 const float *w = templateFor(i, n, f);
180 181
181 if (n >= m_lowestPitch && n <= m_highestPitch) { 182 if (n >= m_lowestPitch && n <= m_highestPitch) {
182 183
183 for (int j = 0; j < m_binCount; ++j) { 184 for (int j = 0; j < m_binCount; ++j) {
184 newPitches[n] += w[j] * m_q[j] * factor; 185 newPitches[n] += w[j] * m_q[j] * factor;