Mercurial > hg > silvet
comparison src/EM.cpp @ 55:384338fa460d preshift
Support shifts as an additional dimension (as in the original model). Also return velocity.
author | Chris Cannam |
---|---|
date | Tue, 08 Apr 2014 13:30:32 +0100 |
parents | a54df67e607e |
children | 3e7e3c610fae |
comparison
equal
deleted
inserted
replaced
54:a54df67e607e | 55:384338fa460d |
---|---|
31 static double epsilon = 1e-16; | 31 static double epsilon = 1e-16; |
32 | 32 |
33 EM::EM() : | 33 EM::EM() : |
34 m_noteCount(SILVET_TEMPLATE_NOTE_COUNT), | 34 m_noteCount(SILVET_TEMPLATE_NOTE_COUNT), |
35 m_shiftCount(SILVET_TEMPLATE_MAX_SHIFT * 2 + 1), | 35 m_shiftCount(SILVET_TEMPLATE_MAX_SHIFT * 2 + 1), |
36 m_pitchCount(m_noteCount * m_shiftCount), | |
37 m_binCount(SILVET_TEMPLATE_HEIGHT), | 36 m_binCount(SILVET_TEMPLATE_HEIGHT), |
38 m_instrumentCount(SILVET_TEMPLATE_COUNT), | 37 m_instrumentCount(SILVET_TEMPLATE_COUNT), |
39 m_pitchSparsity(1.1), | 38 m_pitchSparsity(1.1), |
40 m_sourceSparsity(1.3) | 39 m_sourceSparsity(1.3) |
41 { | 40 { |
42 m_lowestPitch = | 41 m_lowestPitch = silvet_templates_lowest_note; |
43 silvet_templates_lowest_note * m_shiftCount; | 42 m_highestPitch = silvet_templates_highest_note; |
44 m_highestPitch = | 43 |
45 silvet_templates_highest_note * m_shiftCount + m_shiftCount - 1; | 44 m_pitches = V(m_noteCount); |
46 | 45 for (int n = 0; n < m_noteCount; ++n) { |
47 m_pitches = V(m_pitchCount); | |
48 | |
49 for (int n = 0; n < m_pitchCount; ++n) { | |
50 m_pitches[n] = drand48(); | 46 m_pitches[n] = drand48(); |
47 } | |
48 | |
49 m_shifts = Grid(m_shiftCount); | |
50 for (int f = 0; f < m_shiftCount; ++f) { | |
51 m_shifts[f] = V(m_noteCount); | |
52 for (int n = 0; n < m_noteCount; ++n) { | |
53 m_shifts[f][n] = drand48(); | |
54 } | |
51 } | 55 } |
52 | 56 |
53 m_sources = Grid(m_instrumentCount); | 57 m_sources = Grid(m_instrumentCount); |
54 | 58 for (int i = 0; i < m_instrumentCount; ++i) { |
55 for (int i = 0; i < m_instrumentCount; ++i) { | 59 m_sources[i] = V(m_noteCount); |
56 m_sources[i] = V(m_pitchCount); | 60 for (int n = 0; n < m_noteCount; ++n) { |
57 for (int n = 0; n < m_pitchCount; ++n) { | |
58 m_sources[i][n] = (inRange(i, n) ? 1.0 : 0.0); | 61 m_sources[i][n] = (inRange(i, n) ? 1.0 : 0.0); |
59 } | 62 } |
60 } | 63 } |
61 | 64 |
62 m_estimate = V(m_binCount); | 65 m_estimate = V(m_binCount); |
68 } | 71 } |
69 | 72 |
70 void | 73 void |
71 EM::rangeFor(int instrument, int &minPitch, int &maxPitch) | 74 EM::rangeFor(int instrument, int &minPitch, int &maxPitch) |
72 { | 75 { |
73 minPitch = silvet_templates[instrument].lowest * m_shiftCount; | 76 minPitch = silvet_templates[instrument].lowest; |
74 maxPitch = silvet_templates[instrument].highest * m_shiftCount | 77 maxPitch = silvet_templates[instrument].highest; |
75 + m_shiftCount - 1; | |
76 } | 78 } |
77 | 79 |
78 bool | 80 bool |
79 EM::inRange(int instrument, int pitch) | 81 EM::inRange(int instrument, int pitch) |
80 { | 82 { |
82 rangeFor(instrument, minPitch, maxPitch); | 84 rangeFor(instrument, minPitch, maxPitch); |
83 return (pitch >= minPitch && pitch <= maxPitch); | 85 return (pitch >= minPitch && pitch <= maxPitch); |
84 } | 86 } |
85 | 87 |
86 void | 88 void |
87 EM::normalise(V &column) | 89 EM::normaliseColumn(V &column) |
88 { | 90 { |
89 double sum = 0.0; | 91 double sum = 0.0; |
90 for (int i = 0; i < (int)column.size(); ++i) { | 92 for (int i = 0; i < (int)column.size(); ++i) { |
91 sum += column[i]; | 93 sum += column[i]; |
92 } | 94 } |
94 column[i] /= sum; | 96 column[i] /= sum; |
95 } | 97 } |
96 } | 98 } |
97 | 99 |
98 void | 100 void |
99 EM::normaliseSources(Grid &sources) | 101 EM::normaliseGrid(Grid &grid) |
100 { | 102 { |
101 V denominators(sources[0].size()); | 103 V denominators(grid[0].size()); |
102 | 104 |
103 for (int i = 0; i < (int)sources.size(); ++i) { | 105 for (int i = 0; i < (int)grid.size(); ++i) { |
104 for (int j = 0; j < (int)sources[i].size(); ++j) { | 106 for (int j = 0; j < (int)grid[i].size(); ++j) { |
105 denominators[j] += sources[i][j]; | 107 denominators[j] += grid[i][j]; |
106 } | 108 } |
107 } | 109 } |
108 | 110 |
109 for (int i = 0; i < (int)sources.size(); ++i) { | 111 for (int i = 0; i < (int)grid.size(); ++i) { |
110 for (int j = 0; j < (int)sources[i].size(); ++j) { | 112 for (int j = 0; j < (int)grid[i].size(); ++j) { |
111 sources[i][j] /= denominators[j]; | 113 grid[i][j] /= denominators[j]; |
112 } | 114 } |
113 } | 115 } |
114 } | 116 } |
115 | 117 |
116 void | 118 void |
117 EM::iterate(V column) | 119 EM::iterate(V column) |
118 { | 120 { |
119 normalise(column); | 121 normaliseColumn(column); |
120 expectation(column); | 122 expectation(column); |
121 maximisation(column); | 123 maximisation(column); |
122 } | 124 } |
123 | 125 |
124 const float * | 126 const float * |
125 EM::templateFor(int instrument, int pitch) | 127 EM::templateFor(int instrument, int note, int shift) |
126 { | 128 { |
127 int note = pitch / m_shiftCount; | |
128 int shift = pitch % m_shiftCount; | |
129 return silvet_templates[instrument].data[note] + shift; | 129 return silvet_templates[instrument].data[note] + shift; |
130 } | 130 } |
131 | 131 |
132 void | 132 void |
133 EM::expectation(const V &column) | 133 EM::expectation(const V &column) |
137 for (int i = 0; i < m_binCount; ++i) { | 137 for (int i = 0; i < m_binCount; ++i) { |
138 m_estimate[i] = epsilon; | 138 m_estimate[i] = epsilon; |
139 } | 139 } |
140 | 140 |
141 for (int i = 0; i < m_instrumentCount; ++i) { | 141 for (int i = 0; i < m_instrumentCount; ++i) { |
142 for (int n = 0; n < m_pitchCount; ++n) { | 142 for (int n = 0; n < m_noteCount; ++n) { |
143 const float *w = templateFor(i, n); | 143 for (int f = 0; f < m_shiftCount; ++f) { |
144 double pitch = m_pitches[n]; | 144 const float *w = templateFor(i, n, f); |
145 double source = m_sources[i][n]; | 145 double pitch = m_pitches[n]; |
146 for (int j = 0; j < m_binCount; ++j) { | 146 double source = m_sources[i][n]; |
147 m_estimate[j] += w[j] * pitch * source; | 147 double shift = m_shifts[f][n]; |
148 for (int j = 0; j < m_binCount; ++j) { | |
149 m_estimate[j] += w[j] * pitch * source * shift; | |
150 } | |
148 } | 151 } |
149 } | 152 } |
150 } | 153 } |
151 | 154 |
152 for (int i = 0; i < m_binCount; ++i) { | 155 for (int i = 0; i < m_binCount; ++i) { |
157 void | 160 void |
158 EM::maximisation(const V &column) | 161 EM::maximisation(const V &column) |
159 { | 162 { |
160 V newPitches = m_pitches; | 163 V newPitches = m_pitches; |
161 | 164 |
162 for (int n = 0; n < m_pitchCount; ++n) { | 165 for (int n = 0; n < m_noteCount; ++n) { |
163 newPitches[n] = epsilon; | 166 newPitches[n] = epsilon; |
164 if (n >= m_lowestPitch && n <= m_highestPitch) { | 167 if (n >= m_lowestPitch && n <= m_highestPitch) { |
165 for (int i = 0; i < m_instrumentCount; ++i) { | 168 for (int i = 0; i < m_instrumentCount; ++i) { |
166 const float *w = templateFor(i, n); | 169 for (int f = 0; f < m_shiftCount; ++f) { |
170 const float *w = templateFor(i, n, f); | |
171 double pitch = m_pitches[n]; | |
172 double source = m_sources[i][n]; | |
173 double shift = m_shifts[f][n]; | |
174 for (int j = 0; j < m_binCount; ++j) { | |
175 newPitches[n] += w[j] * m_q[j] * pitch * source * shift; | |
176 } | |
177 } | |
178 } | |
179 } | |
180 if (m_pitchSparsity != 1.0) { | |
181 newPitches[n] = pow(newPitches[n], m_pitchSparsity); | |
182 } | |
183 } | |
184 normaliseColumn(newPitches); | |
185 | |
186 Grid newShifts = m_shifts; | |
187 | |
188 for (int f = 0; f < m_shiftCount; ++f) { | |
189 for (int n = 0; n < m_noteCount; ++n) { | |
190 newShifts[f][n] = epsilon; | |
191 for (int i = 0; i < m_instrumentCount; ++i) { | |
192 const float *w = templateFor(i, n, f); | |
167 double pitch = m_pitches[n]; | 193 double pitch = m_pitches[n]; |
168 double source = m_sources[i][n]; | 194 double source = m_sources[i][n]; |
195 double shift = m_shifts[f][n]; | |
169 for (int j = 0; j < m_binCount; ++j) { | 196 for (int j = 0; j < m_binCount; ++j) { |
170 newPitches[n] += w[j] * m_q[j] * pitch * source; | 197 newShifts[f][n] += w[j] * m_q[j] * pitch * source * shift; |
171 } | 198 } |
172 } | 199 } |
173 } | 200 } |
174 if (m_pitchSparsity != 1.0) { | 201 } |
175 newPitches[n] = pow(newPitches[n], m_pitchSparsity); | 202 normaliseGrid(newShifts); |
176 } | |
177 } | |
178 normalise(newPitches); | |
179 | 203 |
180 Grid newSources = m_sources; | 204 Grid newSources = m_sources; |
181 | 205 |
182 for (int i = 0; i < m_instrumentCount; ++i) { | 206 for (int i = 0; i < m_instrumentCount; ++i) { |
183 for (int n = 0; n < m_pitchCount; ++n) { | 207 for (int n = 0; n < m_noteCount; ++n) { |
184 newSources[i][n] = epsilon; | 208 newSources[i][n] = epsilon; |
185 if (inRange(i, n)) { | 209 if (inRange(i, n)) { |
186 const float *w = templateFor(i, n); | 210 for (int f = 0; f < m_shiftCount; ++f) { |
187 double pitch = m_pitches[n]; | 211 const float *w = templateFor(i, n, f); |
188 double source = m_sources[i][n]; | 212 double pitch = m_pitches[n]; |
189 for (int j = 0; j < m_binCount; ++j) { | 213 double source = m_sources[i][n]; |
190 newSources[i][n] += w[j] * m_q[j] * pitch * source; | 214 double shift = m_shifts[f][n]; |
215 for (int j = 0; j < m_binCount; ++j) { | |
216 newSources[i][n] += w[j] * m_q[j] * pitch * source * shift; | |
217 } | |
191 } | 218 } |
192 } | 219 } |
193 if (m_sourceSparsity != 1.0) { | 220 if (m_sourceSparsity != 1.0) { |
194 newSources[i][n] = pow(newSources[i][n], m_sourceSparsity); | 221 newSources[i][n] = pow(newSources[i][n], m_sourceSparsity); |
195 } | 222 } |
196 } | 223 } |
197 } | 224 } |
198 normaliseSources(newSources); | 225 normaliseGrid(newSources); |
199 | 226 |
200 m_pitches = newPitches; | 227 m_pitches = newPitches; |
228 m_shifts = newShifts; | |
201 m_sources = newSources; | 229 m_sources = newSources; |
202 } | 230 } |
203 | 231 |
204 void | 232 |
205 EM::report() | |
206 { | |
207 vector<int> sounding; | |
208 for (int n = 0; n < m_pitchCount; ++n) { | |
209 if (m_pitches[n] > 0.05) { | |
210 sounding.push_back(n); | |
211 } | |
212 } | |
213 cerr << " sounding: "; | |
214 for (int i = 0; i < (int)sounding.size(); ++i) { | |
215 cerr << sounding[i] << " "; | |
216 int maxj = -1; | |
217 double maxs = 0.0; | |
218 for (int j = 0; j < m_instrumentCount; ++j) { | |
219 if (j == 0 || m_sources[j][sounding[i]] > maxs) { | |
220 maxj = j; | |
221 maxs = m_sources[j][sounding[i]]; | |
222 } | |
223 } | |
224 cerr << silvet_templates[maxj].name << " "; | |
225 } | |
226 cerr << endl; | |
227 } | |
228 |