Chris@13
|
1
|
Chris@13
|
2 module em;
|
Chris@13
|
3
|
Chris@13
|
4 mm = load may.mathmisc;
|
Chris@13
|
5 vec = load may.vector;
|
Chris@13
|
6 mat = load may.matrix;
|
Chris@13
|
7
|
// True when note lies within the inclusive [lowest, highest] bounds stored
// at ranges[inst].
inRange ranges inst note =
   (bounds = ranges[inst];
    bounds.lowest <= note and note <= bounds.highest);
Chris@26
|
10
|
// Scale vector v by the reciprocal of its element sum (via vec.divideBy).
// When the sum is not greater than zero, v is returned untouched.
normaliseColumn v =
   (total = vec.sum v;
    if total > 0 then vec.divideBy total v else v fi);
Chris@26
|
16
|
// Apply normaliseColumn to every column of the given matrix.
normaliseChunk chunk =
    mat.mapColumns normaliseColumn chunk;
Chris@26
|
19
|
// Divide each source matrix entry-wise by the sum of all source matrices.
// The result is an array in the same order as sourceMatrices.
// NOTE(review): entries whose summed value is zero are divided as-is; the
// outcome then depends on mat.entryWiseDivide -- confirm that is acceptable.
// Indexing element 0 means sourceMatrices must be non-empty.
normaliseSources sourceMatrices =
   (zero = mat.zeroMatrix (mat.size sourceMatrices[0]);
    accumulate running m = mat.sum [running, m];
    denominator = fold accumulate zero sourceMatrices;
    overDenominator m = mat.entryWiseDivide m denominator;
    array (map overDenominator sourceMatrices));
Chris@14
|
25
|
// Build the initial model record.
//
//   ranges    - hash from instrument name to a { lowest, highest } record
//   templates - indexed by instrument name; each entry is column-normalised
//   notes     - number of notes (rows of the pitch and source matrices)
//   size      - record whose .columns gives the matrix width
//
// Instruments are ordered by sorted name; the returned ranges, templates and
// sources arrays all use that order.
initialise ranges templates notes size =
   (instrumentNames = sort (keys ranges);
    rangeList = array (map (at ranges) instrumentNames);
    shape = { rows = notes, columns = size.columns };

    // Column vector with 1 for each note inside the instrument's range,
    // 0 elsewhere.
    inRangeMask inst =
        mat.newColumnVector
           (vec.fromList
               (map do note: if inRange rangeList inst note then 1 else 0 fi done
                    [0..notes-1]));

    {
        // z in the original. 1xN per note
        pitches = normaliseChunk (mat.randomMatrix shape),

        // u in the original. a tensor, 1xN per note-instrument
        sources =
            normaliseSources
               (array
                   (map do inst: mat.tiledTo shape (inRangeMask inst) done
                        [0..length instrumentNames - 1])),

        instrumentNames,
        nInstruments = length instrumentNames,
        nNotes = notes,
        ranges = rangeList,
        templates =
            array
               (map do name: normaliseChunk templates[name] done
                    instrumentNames),

        // Smallest lower bound and largest upper bound over all instruments.
        lowest = head (sort (map do bounds: bounds.lowest done rangeList)),
        highest = head (reverse (sort (map do bounds: bounds.highest done rangeList))),
    });
Chris@14
|
58
|
// Small positive constant used to seed the accumulator matrices in
// performExpectation and performMaximisation.
epsilon = 1e-16;
Chris@14
|
60
|
// Keep only the elements for which predicate holds, preserving their order.
select predicate =
    concatMap do candidate:
        if predicate candidate then
            [candidate]
        else
            []
        fi
    done;
Chris@14
|
62
|
// Gather the three vectors for one instrument/note pair:
//   w - column `note` of data.templates[inst], as a column vector
//   p - row `note` of data.pitches, as a row vector
//   s - row `note` of data.sources[inst], as a row vector
distributionsFor data inst note =
   (noteRowOf m = mat.newRowVector (mat.getRow note m);
    {
        w = mat.newColumnVector (mat.getColumn note data.templates[inst]),
        p = noteRowOf data.pitches,
        s = noteRowOf data.sources[inst],
    });
Chris@26
|
69
|
// Expectation step.
//
// Starting from a matrix of epsilon the same size as chunk, add w * (p .* s)
// for every instrument and every note in that instrument's inclusive range.
// Returns { estimate, q } where q is chunk divided entry-wise by estimate.
performExpectation data chunk =
   (base = mat.constMatrix epsilon (mat.size chunk);

    // Contribution of a single instrument/note pair to the estimate.
    contribution inst note =
       (d = distributionsFor data inst note;
        mat.product d.w (mat.entryWiseProduct [d.p, d.s]));

    // Add every in-range note of one instrument to the running total.
    addInstrument total inst =
       (bounds = data.ranges[inst];
        fold do running note:
            mat.sum [running, contribution inst note]
        done total [bounds.lowest .. bounds.highest]);

    estimate = fold addInstrument base [0..data.nInstruments-1];
    { estimate, q = mat.entryWiseDivide chunk estimate });
Chris@14
|
80
|
// Maximisation step: returns data with its pitches and sources replaced.
//
// The chunk is column-normalised first. For every note between data.lowest
// and data.highest, and for every instrument, a row is accumulated over the
// bins of the chunk as  w[bin] * (p .* s .* chunkRow[bin]),  starting from a
// row of epsilon. That row is cached per {inst, note} and summed over the
// instruments to give the note's pitch row. Source rows reuse the cached
// rows for in-range notes. Rows outside the relevant range are epsilon.
//
// NOTE(review): the q argument is accepted but never referenced here.
performMaximisation data chunk q =
   (normalised = normaliseChunk chunk;
    epsilonRow =
        mat.constMatrix epsilon { rows = 1, columns = mat.width normalised };
    noteIndices = [0..data.nNotes - 1];
    instIndices = [0..data.nInstruments-1];
    binIndices = [0..mat.height normalised - 1];

    // Filled while the pitch rows are built, read while the source rows
    // are built; the pitch matrix must therefore be constructed first.
    cache = [:];

    // Bin-weighted product row for one instrument/note pair.
    weightedRow inst note =
       (d = distributionsFor data inst note;
        fold do running bin:
            observed = mat.newRowVector (mat.getRow bin normalised);
            mat.sum
               [running,
                mat.scaled (mat.at d.w bin 0)
                   (mat.entryWiseProduct [d.p, d.s, observed])];
        done epsilonRow binIndices);

    pitchRow note =
        if note < data.lowest or note > data.highest then
            epsilonRow
        else
            fold do running inst:
                row = weightedRow inst note;
                cache[{inst, note}] := row;
                mat.sum [running, row];
            done epsilonRow instIndices
        fi;

    pitches = normaliseChunk (mat.concatVertical (map pitchRow noteIndices));

    sourceRow inst note =
        if inRange data.ranges inst note then
            cache[{inst, note}]
        else
            epsilonRow
        fi;

    sources =
        normaliseSources
           (array
               (map do inst:
                    mat.concatVertical (map (sourceRow inst) noteIndices)
                done instIndices));

    data with {
        pitches,
        sources,
    });
Chris@26
|
123
|
Chris@26
|
124
|
Chris@13
|
125
|
// Module value: the functions exposed by the em module.
{
    initialise,
    performExpectation,
    performMaximisation
}
Chris@13
|
131
|
Chris@13
|
132
|