Chris@13
|
1
|
Chris@13
|
2 module em;
|
Chris@13
|
3
|
Chris@13
|
4 mm = load may.mathmisc;
|
Chris@13
|
5 vec = load may.vector;
|
Chris@13
|
6 mat = load may.matrix;
|
Chris@13
|
7
|
Chris@14
|
// True when note lies inside the inclusive range [lowest, highest] that
// ranges holds for instrument. ranges is indexed by instrument and each entry
// must provide lowest and highest fields.
inRange ranges instrument note =
   (bounds = ranges[instrument];
    bounds.lowest <= note and note <= bounds.highest);
|
Chris@14
|
10
|
Chris@14
|
// Build the initial state consumed by performExpectation/performMaximisation.
//   ranges    - hash indexed by instrument; each value has lowest and highest
//               note fields (the key type is not fixed here)
//   templates - stored unchanged in the result
//   notes     - number of notes; note indices run 0 .. notes-1
//   size      - record with rows and columns; only columns sizes the pitch
//               rows, and the source rows use size with rows = 1
// Returns a record holding pitches, sources, instruments (the keys of
// ranges), templates, ranges, the smallest lowest note and the largest
// highest note over all instruments. Fails (head of an empty list) when
// ranges is empty.
initialise ranges templates notes size =
   (instruments = keys ranges;
    noteIndices = [0..notes-1];
    rowShape = size with { rows = 1 };
    sortedLows = sort (map do i: ranges[i].lowest done instruments);
    sortedHighs = reverse (sort (map do i: ranges[i].highest done instruments));
    // One 1xN row per note: 1 everywhere if the instrument can play the
    // note, 0 everywhere otherwise.
    sourceRows instrument =
        array (map do note:
            mat.constMatrix
                (if inRange ranges instrument note then 1 else 0 fi)
                rowShape
        done noteIndices);
    {
        // z in the original: one random 1xN row per note.
        pitches = array (map do n:
            mat.randomMatrix { rows = 1, columns = size.columns }
        done noteIndices),
        // u in the original: per instrument, one 1xN row per note.
        sources = mapIntoHash id sourceRows instruments,
        instruments,
        templates,
        ranges,
        lowest = head sortedLows,
        highest = head sortedHighs,
    });
|
Chris@14
|
33
|
Chris@14
|
// Small constant used as the starting value of every cell of the
// accumulator matrices in performExpectation and performMaximisation. In
// performExpectation this means the estimate that chunk is divided by starts
// at 1e-16 rather than 0 before any terms are added.
epsilon = 1e-16;

// Keep the elements of a list for which predicate holds, preserving their
// order. This is a thin alias for the standard library's filter, kept so that
// existing callers of select continue to work.
select predicate = filter predicate;
|
Chris@14
|
37
|
Chris@14
|
// Expectation step (per the function name): rebuild the model's estimate of
// chunk and return chunk divided entry-wise by that estimate.
// The estimate starts at epsilon in every cell of a chunk-sized matrix. For
// each instrument in data.instruments, and each note from that instrument's
// lowest to highest (inclusive), it adds the entry-wise product of:
//   w - the instrument's template column for the note, repeated across the
//       width of chunk
//   p - data.pitches[note], repeated down the height of chunk
//   s - data.sources[instrument][note], repeated down the height of chunk
performExpectation data chunk =
   (height = mat.height chunk;
    width = mat.width chunk;
    // Add the w * p * s term for one (instrument, note) pair to total.
    addNote instrument total note =
       (column = mat.getColumn note data.templates[instrument];
        w = mat.repeatedHorizontal width (mat.newColumnVector column);
        p = mat.repeatedVertical height data.pitches[note];
        s = mat.repeatedVertical height data.sources[instrument][note];
        mat.sum [total, mat.entryWiseProduct [w, p, s]]);
    // Add the terms for every note in this instrument's own range.
    addInstrument total instrument =
       (bounds = data.ranges[instrument];
        fold (addNote instrument) total [bounds.lowest .. bounds.highest]);
    estimate = fold addInstrument
        (mat.constMatrix epsilon (mat.size chunk)) data.instruments;
    mat.entryWiseDivide chunk estimate);
|
Chris@14
|
51
|
Chris@14
|
// Maximisation-side accumulation (per the function name).
// Starting from a chunk-sized matrix holding epsilon in every cell, for each
// note from data.lowest to data.highest (inclusive), and for each instrument
// whose range contains that note, add the entry-wise product of:
//   w     - the instrument's template column for the note, repeated across
//           the width of chunk
//   s     - data.sources[instrument][note], repeated down the height of chunk
//   error - presumably the ratio matrix returned by performExpectation;
//           confirm against the caller
// Returns that single accumulated matrix.
// NOTE(review): contributions from every note and instrument are summed into
// one matrix, and the pitch rows are not used here. If a per-note (or
// per-source) update is intended, the notes still need to be kept separate;
// confirm.
performMaximisation data chunk error =
   (height = mat.height chunk;
    width = mat.width chunk;
    // Add the w * s * error term for one (note, instrument) pair to total.
    // (No pitch term: the old unused repeated-pitch matrix was dropped.)
    addInstrument note total instrument =
       (template = mat.getColumn note data.templates[instrument];
        w = mat.repeatedHorizontal width (mat.newColumnVector template);
        s = mat.repeatedVertical height data.sources[instrument][note];
        mat.sum [total, mat.entryWiseProduct [w, s, error]]);
    // Add the terms for every instrument able to play this note.
    addNote total note =
        fold (addInstrument note) total
            (select do i: inRange data.ranges i note done data.instruments);
    fold addNote (mat.constMatrix epsilon (mat.size chunk))
        [data.lowest .. data.highest]);
|
Chris@13
|
65
|
Chris@13
|
// Value of the em module: the three public entry points.
{ initialise, performExpectation, performMaximisation }
|
Chris@13
|
71
|
Chris@13
|
72
|