comparison src/EM.cpp @ 36:74b77a4d6552

Fill out the EM logic
author Chris Cannam
date Fri, 04 Apr 2014 17:48:06 +0100
parents 461d94ed3816
children 947996aac974
comparison
equal deleted inserted replaced
35:461d94ed3816 36:74b77a4d6552
15 15
16 #include "EM.h" 16 #include "EM.h"
17 17
18 #include "data/include/templates.h" 18 #include "data/include/templates.h"
19 19
20 #include <cstdlib>
21
22 #include <iostream>
23
24 #include <vector>
25
26 using std::vector;
27 using std::cerr;
28 using std::endl;
29
// Small positive floor used to seed the model estimate in expectation() and
// every M-step accumulator in maximisation(). It is read-only, so declare it
// const and let the compiler enforce that.
static const double epsilon = 1e-16;
21 31
22 EM::EM() : 32 EM::EM() :
23 m_notes(SILVET_TEMPLATE_NOTE_COUNT), 33 m_notes(SILVET_TEMPLATE_NOTE_COUNT),
24 m_bins(SILVET_TEMPLATE_HEIGHT), 34 m_bins(SILVET_TEMPLATE_HEIGHT),
25 m_instruments(SILVET_TEMPLATE_COUNT) 35 m_instruments(SILVET_TEMPLATE_COUNT)
26 { 36 {
37 cerr << "init!" << endl;
38
27 m_lowest = 0; 39 m_lowest = 0;
28 m_highest = m_notes - 1; 40 m_highest = m_notes - 1;
29 41
30 for (int i = 0; i < m_instruments; ++i) { 42 for (int i = 0; i < m_instruments; ++i) {
31 if (i == 0 || silvet_templates[i].lowest < m_lowest) { 43 if (i == 0 || silvet_templates[i].lowest < m_lowest) {
36 } 48 }
37 } 49 }
38 50
39 m_pitches = V(m_notes); 51 m_pitches = V(m_notes);
40 52
41 for (int n = 0; n < m_notes; ++i) { 53 for (int n = 0; n < m_notes; ++n) {
42 m_pitches[n] = drand48(); 54 m_pitches[n] = drand48();
43 } 55 }
44 56
45 m_sources = Grid(m_instruments); 57 m_sources = Grid(m_instruments);
46 58
49 for (int n = 0; n < m_notes; ++n) { 61 for (int n = 0; n < m_notes; ++n) {
50 m_sources[i][n] = (inRange(i, n) ? 1.0 : 0.0); 62 m_sources[i][n] = (inRange(i, n) ? 1.0 : 0.0);
51 } 63 }
52 } 64 }
53 65
66 m_estimate = V(m_bins);
54 m_q = V(m_bins); 67 m_q = V(m_bins);
55 68
56 for (int w = 0; w < m_bins; ++w) { 69 cerr << "(init done)" << endl;
57 m_q[w] = epsilon;
58 }
59 } 70 }
60 71
61 EM::~EM() 72 EM::~EM()
62 { 73 {
63 } 74 }
67 { 78 {
68 return (note >= silvet_templates[instrument].lowest && 79 return (note >= silvet_templates[instrument].lowest &&
69 note <= silvet_templates[instrument].highest); 80 note <= silvet_templates[instrument].highest);
70 } 81 }
71 82
83 void
84 EM::normalise(V &column)
85 {
86 double sum = 0.0;
87 for (int i = 0; i < (int)column.size(); ++i) {
88 sum += column[i];
89 }
90 for (int i = 0; i < (int)column.size(); ++i) {
91 column[i] /= sum;
92 }
93 }
94
95 void
96 EM::iterate(V column)
97 {
98 normalise(column);
99 expectation(column);
100 maximisation(column);
101 }
102
103 void
104 EM::expectation(const V &column)
105 {
106 cerr << ".";
107
108 for (int i = 0; i < m_bins; ++i) {
109 m_estimate[i] = epsilon;
110 }
111
112 for (int i = 0; i < m_instruments; ++i) {
113 for (int n = 0; n < m_notes; ++n) {
114 float *w = silvet_templates[i].data[n];
115 double pitch = m_pitches[n];
116 double source = m_sources[i][n];
117 for (int j = 0; j < m_bins; ++j) {
118 m_estimate[j] += w[j] * pitch * source;
119 }
120 }
121 }
122
123 for (int i = 0; i < m_bins; ++i) {
124 m_q[i] = column[i] / m_estimate[i];
125 }
126 }
127
128 void
129 EM::maximisation(const V &column)
130 {
131 V newPitches = m_pitches;
132
133 for (int n = 0; n < m_notes; ++n) {
134 newPitches[n] = epsilon;
135 if (n >= m_lowest && n <= m_highest) {
136 for (int i = 0; i < m_instruments; ++i) {
137 float *w = silvet_templates[i].data[n];
138 double pitch = m_pitches[n];
139 double source = m_sources[i][n];
140 for (int j = 0; j < m_bins; ++j) {
141 newPitches[n] += w[j] * m_q[j] * pitch * source;
142 }
143 }
144 }
145 }
146 normalise(newPitches);
147
148 Grid newSources = m_sources;
149
150 for (int i = 0; i < m_instruments; ++i) {
151 for (int n = 0; n < m_notes; ++n) {
152 newSources[i][n] = epsilon;
153 if (inRange(i, n)) {
154 float *w = silvet_templates[i].data[n];
155 for (int j = 0; j < m_bins; ++j) {
156 newSources[i][n] +=
157 w[j] * m_q[j] * m_pitches[n] * m_sources[i][n];
158 }
159 }
160 }
161 normalise(newSources[i]);
162 }
163
164 m_pitches = newPitches;
165 m_sources = newSources;
166 }
167
168 void
169 EM::report()
170 {
171 vector<int> sounding;
172 for (int n = 0; n < m_notes; ++n) {
173 if (m_pitches[n] > 0.05) {
174 sounding.push_back(n);
175 }
176 }
177 cerr << " sounding: ";
178 for (int i = 0; i < (int)sounding.size(); ++i) {
179 cerr << sounding[i] << " ";
180 }
181 cerr << endl;
182 }
183