Mercurial > hg > silvet
comparison src/EM.cpp @ 127:df05f855f63b
Merge from branch bqvec-openmp
author | Chris Cannam |
---|---|
date | Wed, 07 May 2014 11:55:32 +0100 |
parents | 7377032e0bf1 |
children | f25b8e7de0ed |
comparison
equal
deleted
inserted
replaced
110:e282930cfca7 | 127:df05f855f63b |
---|---|
20 #include <cstdlib> | 20 #include <cstdlib> |
21 #include <cmath> | 21 #include <cmath> |
22 | 22 |
23 #include <iostream> | 23 #include <iostream> |
24 | 24 |
25 #include <vector> | 25 #include "VectorOps.h" |
26 #include "Allocators.h" | |
26 | 27 |
27 using std::vector; | 28 using std::vector; |
28 using std::cerr; | 29 using std::cerr; |
29 using std::endl; | 30 using std::endl; |
30 | 31 |
32 using namespace breakfastquay; | |
33 | |
31 static double epsilon = 1e-16; | 34 static double epsilon = 1e-16; |
32 | 35 |
33 EM::EM(bool useShifts) : | 36 EM::EM(bool useShifts) : |
34 m_useShifts(useShifts), | |
35 m_noteCount(SILVET_TEMPLATE_NOTE_COUNT), | 37 m_noteCount(SILVET_TEMPLATE_NOTE_COUNT), |
36 m_shiftCount(useShifts ? SILVET_TEMPLATE_MAX_SHIFT * 2 + 1 : 1), | 38 m_shiftCount(useShifts ? SILVET_TEMPLATE_MAX_SHIFT * 2 + 1 : 1), |
37 m_binCount(SILVET_TEMPLATE_HEIGHT), | 39 m_binCount(SILVET_TEMPLATE_HEIGHT), |
38 m_instrumentCount(SILVET_TEMPLATE_COUNT), | 40 m_sourceCount(SILVET_TEMPLATE_COUNT), |
39 m_pitchSparsity(1.1), | 41 m_pitchSparsity(1.1), |
40 m_sourceSparsity(1.3) | 42 m_sourceSparsity(1.3), |
41 { | 43 m_lowestPitch(silvet_templates_lowest_note), |
42 m_lowestPitch = silvet_templates_lowest_note; | 44 m_highestPitch(silvet_templates_highest_note) |
43 m_highestPitch = silvet_templates_highest_note; | 45 { |
44 | 46 m_pitches = allocate<double>(m_noteCount); |
45 m_pitches = V(m_noteCount); | 47 m_updatePitches = allocate<double>(m_noteCount); |
46 for (int n = 0; n < m_noteCount; ++n) { | 48 for (int n = 0; n < m_noteCount; ++n) { |
47 m_pitches[n] = drand48(); | 49 m_pitches[n] = drand48(); |
48 } | 50 } |
49 | 51 |
50 m_shifts = Grid(m_shiftCount); | 52 if (useShifts) { |
51 for (int f = 0; f < m_shiftCount; ++f) { | 53 m_shifts = allocate_channels<double>(m_shiftCount, m_noteCount); |
52 m_shifts[f] = V(m_noteCount); | 54 m_updateShifts = allocate_channels<double>(m_shiftCount, m_noteCount); |
53 for (int n = 0; n < m_noteCount; ++n) { | 55 for (int f = 0; f < m_shiftCount; ++f) { |
54 if (m_useShifts) { | 56 for (int n = 0; n < m_noteCount; ++n) { |
55 m_shifts[f][n] = drand48(); | 57 m_shifts[f][n] = drand48(); |
56 } else { | 58 } |
57 m_shifts[f][n] = 1.0; | 59 } |
58 } | 60 } else { |
59 } | 61 m_shifts = 0; |
62 m_updateShifts = 0; | |
60 } | 63 } |
61 | 64 |
62 m_sources = Grid(m_instrumentCount); | 65 m_sources = allocate_channels<double>(m_sourceCount, m_noteCount); |
63 for (int i = 0; i < m_instrumentCount; ++i) { | 66 m_updateSources = allocate_channels<double>(m_sourceCount, m_noteCount); |
64 m_sources[i] = V(m_noteCount); | 67 for (int i = 0; i < m_sourceCount; ++i) { |
65 for (int n = 0; n < m_noteCount; ++n) { | 68 for (int n = 0; n < m_noteCount; ++n) { |
66 m_sources[i][n] = (inRange(i, n) ? 1.0 : 0.0); | 69 m_sources[i][n] = (inRange(i, n) ? 1.0 : 0.0); |
67 } | 70 } |
68 } | 71 } |
69 | 72 |
70 m_estimate = V(m_binCount); | 73 m_estimate = allocate<double>(m_binCount); |
71 m_q = V(m_binCount); | 74 m_q = allocate<double>(m_binCount); |
72 } | 75 } |
73 | 76 |
74 EM::~EM() | 77 EM::~EM() |
75 { | 78 { |
79 deallocate(m_q); | |
80 deallocate(m_estimate); | |
81 deallocate_channels(m_sources, m_sourceCount); | |
82 deallocate_channels(m_updateSources, m_sourceCount); | |
83 deallocate_channels(m_shifts, m_shiftCount); | |
84 deallocate_channels(m_updateShifts, m_shiftCount); | |
85 deallocate(m_pitches); | |
86 deallocate(m_updatePitches); | |
76 } | 87 } |
77 | 88 |
78 void | 89 void |
79 EM::rangeFor(int instrument, int &minPitch, int &maxPitch) | 90 EM::rangeFor(int instrument, int &minPitch, int &maxPitch) |
80 { | 91 { |
89 rangeFor(instrument, minPitch, maxPitch); | 100 rangeFor(instrument, minPitch, maxPitch); |
90 return (pitch >= minPitch && pitch <= maxPitch); | 101 return (pitch >= minPitch && pitch <= maxPitch); |
91 } | 102 } |
92 | 103 |
93 void | 104 void |
94 EM::normaliseColumn(V &column) | 105 EM::normaliseColumn(double *column, int size) |
95 { | 106 { |
96 double sum = 0.0; | 107 double sum = v_sum(column, size); |
97 for (int i = 0; i < (int)column.size(); ++i) { | 108 v_scale(column, 1.0 / sum, size); |
98 sum += column[i]; | 109 } |
99 } | 110 |
100 for (int i = 0; i < (int)column.size(); ++i) { | 111 void |
101 column[i] /= sum; | 112 EM::normaliseGrid(double **grid, int size1, int size2) |
102 } | 113 { |
103 } | 114 double *denominators = allocate_and_zero<double>(size2); |
104 | 115 |
105 void | 116 for (int i = 0; i < size1; ++i) { |
106 EM::normaliseGrid(Grid &grid) | 117 for (int j = 0; j < size2; ++j) { |
107 { | |
108 V denominators(grid[0].size()); | |
109 | |
110 for (int i = 0; i < (int)grid.size(); ++i) { | |
111 for (int j = 0; j < (int)grid[i].size(); ++j) { | |
112 denominators[j] += grid[i][j]; | 118 denominators[j] += grid[i][j]; |
113 } | 119 } |
114 } | 120 } |
115 | 121 |
116 for (int i = 0; i < (int)grid.size(); ++i) { | 122 for (int i = 0; i < size1; ++i) { |
117 for (int j = 0; j < (int)grid[i].size(); ++j) { | 123 v_divide(grid[i], denominators, size2); |
118 grid[i][j] /= denominators[j]; | 124 } |
119 } | 125 |
120 } | 126 deallocate(denominators); |
121 } | 127 } |
122 | 128 |
123 void | 129 void |
124 EM::iterate(V column) | 130 EM::iterate(const double *column) |
125 { | 131 { |
126 normaliseColumn(column); | 132 double *norm = allocate<double>(m_binCount); |
127 expectation(column); | 133 v_copy(norm, column, m_binCount); |
128 maximisation(column); | 134 normaliseColumn(norm, m_binCount); |
129 } | 135 expectation(norm); |
130 | 136 maximisation(norm); |
131 const float * | 137 deallocate(norm); |
138 } | |
139 | |
140 const double * | |
132 EM::templateFor(int instrument, int note, int shift) | 141 EM::templateFor(int instrument, int note, int shift) |
133 { | 142 { |
134 if (m_useShifts) { | 143 if (m_shifts) { |
135 return silvet_templates[instrument].data[note] + shift; | 144 return silvet_templates[instrument].data[note] + shift; |
136 } else { | 145 } else { |
137 return silvet_templates[instrument].data[note] + | 146 return silvet_templates[instrument].data[note] + |
138 SILVET_TEMPLATE_MAX_SHIFT; | 147 SILVET_TEMPLATE_MAX_SHIFT; |
139 } | 148 } |
140 } | 149 } |
141 | 150 |
142 void | 151 void |
143 EM::expectation(const V &column) | 152 EM::expectation(const double *column) |
144 { | 153 { |
145 // cerr << "."; | 154 // cerr << "."; |
146 | 155 |
147 for (int i = 0; i < m_binCount; ++i) { | 156 v_set(m_estimate, epsilon, m_binCount); |
148 m_estimate[i] = epsilon; | 157 |
149 } | 158 for (int i = 0; i < m_sourceCount; ++i) { |
150 | 159 for (int n = 0; n < m_noteCount; ++n) { |
151 for (int i = 0; i < m_instrumentCount; ++i) { | 160 const double pitch = m_pitches[n]; |
152 for (int n = 0; n < m_noteCount; ++n) { | 161 const double source = m_sources[i][n]; |
153 for (int f = 0; f < m_shiftCount; ++f) { | 162 for (int f = 0; f < m_shiftCount; ++f) { |
154 const float *w = templateFor(i, n, f); | 163 const double *w = templateFor(i, n, f); |
155 double pitch = m_pitches[n]; | 164 const double shift = m_shifts ? m_shifts[f][n] : 1.0; |
156 double source = m_sources[i][n]; | 165 const double factor = pitch * source * shift; |
157 double shift = m_shifts[f][n]; | 166 v_add_with_gain(m_estimate, w, factor, m_binCount); |
158 for (int j = 0; j < m_binCount; ++j) { | |
159 m_estimate[j] += w[j] * pitch * source * shift; | |
160 } | |
161 } | 167 } |
162 } | 168 } |
163 } | 169 } |
164 | 170 |
165 for (int i = 0; i < m_binCount; ++i) { | 171 for (int i = 0; i < m_binCount; ++i) { |
166 m_q[i] = column[i] / m_estimate[i]; | 172 m_q[i] = column[i] / m_estimate[i]; |
167 } | 173 } |
168 } | 174 } |
169 | 175 |
170 void | 176 void |
171 EM::maximisation(const V &column) | 177 EM::maximisation(const double *column) |
172 { | 178 { |
173 V newPitches = m_pitches; | 179 v_set(m_updatePitches, epsilon, m_noteCount); |
180 | |
181 for (int i = 0; i < m_sourceCount; ++i) { | |
182 v_set(m_updateSources[i], epsilon, m_noteCount); | |
183 } | |
184 | |
185 if (m_shifts) { | |
186 for (int i = 0; i < m_shiftCount; ++i) { | |
187 v_set(m_updateShifts[i], epsilon, m_noteCount); | |
188 } | |
189 } | |
190 | |
191 double *contributions = allocate<double>(m_binCount); | |
174 | 192 |
175 for (int n = 0; n < m_noteCount; ++n) { | 193 for (int n = 0; n < m_noteCount; ++n) { |
176 newPitches[n] = epsilon; | 194 |
177 if (n >= m_lowestPitch && n <= m_highestPitch) { | 195 const double pitch = m_pitches[n]; |
178 for (int i = 0; i < m_instrumentCount; ++i) { | 196 |
179 for (int f = 0; f < m_shiftCount; ++f) { | 197 for (int f = 0; f < m_shiftCount; ++f) { |
180 const float *w = templateFor(i, n, f); | 198 |
181 double pitch = m_pitches[n]; | 199 const double shift = m_shifts ? m_shifts[f][n] : 1.0; |
182 double source = m_sources[i][n]; | 200 |
183 double shift = m_shifts[f][n]; | 201 for (int i = 0; i < m_sourceCount; ++i) { |
184 for (int j = 0; j < m_binCount; ++j) { | 202 |
185 newPitches[n] += w[j] * m_q[j] * pitch * source * shift; | 203 const double source = m_sources[i][n]; |
204 const double factor = pitch * source * shift; | |
205 const double *w = templateFor(i, n, f); | |
206 | |
207 v_copy(contributions, w, m_binCount); | |
208 v_multiply(contributions, m_q, m_binCount); | |
209 | |
210 double total = factor * v_sum(contributions, m_binCount); | |
211 | |
212 if (n >= m_lowestPitch && n <= m_highestPitch) { | |
213 | |
214 m_updatePitches[n] += total; | |
215 | |
216 if (inRange(i, n)) { | |
217 m_updateSources[i][n] += total; | |
186 } | 218 } |
187 } | 219 } |
188 } | 220 |
189 } | 221 if (m_shifts) { |
190 if (m_pitchSparsity != 1.0) { | 222 m_updateShifts[f][n] += total; |
191 newPitches[n] = pow(newPitches[n], m_pitchSparsity); | |
192 } | |
193 } | |
194 normaliseColumn(newPitches); | |
195 | |
196 Grid newShifts = m_shifts; | |
197 | |
198 if (m_useShifts) { | |
199 for (int f = 0; f < m_shiftCount; ++f) { | |
200 for (int n = 0; n < m_noteCount; ++n) { | |
201 newShifts[f][n] = epsilon; | |
202 for (int i = 0; i < m_instrumentCount; ++i) { | |
203 const float *w = templateFor(i, n, f); | |
204 double pitch = m_pitches[n]; | |
205 double source = m_sources[i][n]; | |
206 double shift = m_shifts[f][n]; | |
207 for (int j = 0; j < m_binCount; ++j) { | |
208 newShifts[f][n] += w[j] * m_q[j] * pitch * source * shift; | |
209 } | |
210 } | 223 } |
211 } | 224 } |
212 } | 225 } |
213 normaliseGrid(newShifts); | 226 } |
214 } | 227 |
215 | 228 if (m_pitchSparsity != 1.0) { |
216 Grid newSources = m_sources; | 229 for (int n = 0; n < m_noteCount; ++n) { |
217 | 230 m_updatePitches[n] = |
218 for (int i = 0; i < m_instrumentCount; ++i) { | 231 pow(m_updatePitches[n], m_pitchSparsity); |
219 for (int n = 0; n < m_noteCount; ++n) { | 232 } |
220 newSources[i][n] = epsilon; | 233 } |
221 if (inRange(i, n)) { | 234 |
222 for (int f = 0; f < m_shiftCount; ++f) { | 235 if (m_sourceSparsity != 1.0) { |
223 const float *w = templateFor(i, n, f); | 236 for (int n = 0; n < m_noteCount; ++n) { |
224 double pitch = m_pitches[n]; | 237 for (int i = 0; i < m_sourceCount; ++i) { |
225 double source = m_sources[i][n]; | 238 m_updateSources[i][n] = |
226 double shift = m_shifts[f][n]; | 239 pow(m_updateSources[i][n], m_sourceSparsity); |
227 for (int j = 0; j < m_binCount; ++j) { | 240 } |
228 newSources[i][n] += w[j] * m_q[j] * pitch * source * shift; | 241 } |
229 } | 242 } |
230 } | 243 |
231 } | 244 normaliseColumn(m_updatePitches, m_noteCount); |
232 if (m_sourceSparsity != 1.0) { | 245 std::swap(m_pitches, m_updatePitches); |
233 newSources[i][n] = pow(newSources[i][n], m_sourceSparsity); | 246 |
234 } | 247 normaliseGrid(m_updateSources, m_sourceCount, m_noteCount); |
235 } | 248 std::swap(m_sources, m_updateSources); |
236 } | 249 |
237 normaliseGrid(newSources); | 250 if (m_shifts) { |
238 | 251 normaliseGrid(m_updateShifts, m_shiftCount, m_noteCount); |
239 m_pitches = newPitches; | 252 std::swap(m_shifts, m_updateShifts); |
240 m_shifts = newShifts; | 253 } |
241 m_sources = newSources; | 254 } |
242 } | 255 |
243 | 256 |
244 |