Chris@13
|
1
|
Chris@13
|
2 module em;
|
Chris@13
|
3
|
Chris@13
|
4 mm = load may.mathmisc;
|
Chris@13
|
5 vec = load may.vector;
|
Chris@13
|
6 mat = load may.matrix;
|
Chris@13
|
7
|
// True when `note` lies inside the inclusive range recorded for
// `instrument` in `ranges` (a hash of records with `lowest`/`highest`).
inRange ranges instrument note =
   (bounds = ranges[instrument];
    bounds.lowest <= note and note <= bounds.highest);
Chris@14
|
10
|
// Build the initial model state.
//   ranges    - hash from instrument to a record with `lowest` and
//               `highest` note indices
//   templates - stored unchanged in the result
//   notes     - number of note slots; note indices run 0 .. notes-1
//   size      - size record with a `columns` field; it is also passed to
//               mat.constMatrix with `rows` overridden to 1 (presumably a
//               mat size record -- TODO confirm)
// Returns a record with:
//   pitches     - array of `notes` 1-row matrices from mat.randomMatrix
//                 ("z" in the original)
//   sources     - hash from instrument to an array of `notes` constant
//                 1-row matrices: 1 where the note is inside that
//                 instrument's range, 0 otherwise ("u" in the original)
//   instruments - keys of `ranges`
//   templates, ranges - the arguments, unchanged
//   lowest      - smallest `lowest` over all instruments
//   highest     - largest `highest` over all instruments
// NOTE(review): an empty `ranges` makes lowest/highest fail on `head` of
// an empty list (same as before this change).
initialise ranges templates notes size =
   (instruments = keys ranges;
    noteIndices = [0 .. notes - 1];
    rowSize = size with { rows = 1 };
    lows = map do i: ranges[i].lowest done instruments;
    highs = map do i: ranges[i].highest done instruments;
    {
        pitches =
            array (map do note:
                mat.randomMatrix { rows = 1, columns = size.columns }
            done noteIndices),
        sources =
            mapIntoHash id
               do instrument:
                   array (map do note:
                       mat.constMatrix
                          (if inRange ranges instrument note then 1 else 0 fi)
                          rowSize
                   done noteIndices)
               done instruments,
        instruments,
        templates,
        ranges,
        // Linear min/max scans instead of sorting the whole list.
        lowest = fold min (head lows) (tail lows),
        highest = fold max (head highs) (tail highs),
    });
Chris@14
|
33
|
// Tiny positive value used to fill the starting matrix of the accumulated
// sums in performExpectation/performMaximisation, so an accumulator whose
// terms are all zero still ends up non-zero.
epsilon = 1e-16;
Chris@14
|
35
|
// Keep, in order, the elements of a list for which `predicate` is true.
// Delegates to the standard `filter` instead of rebuilding it from
// concatMap and singleton/empty lists.
select predicate = filter predicate;
Chris@14
|
37
|
// Expectation step: compute the model's current estimate of `chunk` and
// return `chunk` divided entry-wise by that estimate.
//
// The estimate starts as a matrix filled with `epsilon`, the same size as
// `chunk` (presumably to keep the divisor away from zero). For each
// instrument in data.instruments and each note in that instrument's
// [lowest .. highest] range, it adds the entry-wise product of:
//   - the note's column of the instrument's template,
//   - the note's pitch row (data.pitches),
//   - the instrument/note source row (data.sources),
// each brought to the size of `chunk` with mat.tiledTo.
//
// NOTE(review): note indices come straight from data.ranges. If a range
// reaches past the `notes` count given to initialise, data.pitches[note]
// is out of bounds -- confirm ranges always lie within 0 .. notes-1.
performExpectation data chunk =
   (chunkSize = mat.size chunk;
    // Depends only on `chunk`, so build the tiling function once rather
    // than once per (instrument, note) pair.
    resize = mat.tiledTo chunkSize;
    // Add every in-range note's contribution for one instrument to `acc`.
    addNotes acc instrument =
       (limits = data.ranges[instrument];
        fold do acc note:
            template = mat.getColumn note data.templates[instrument];
            w = resize (mat.newColumnVector template);
            p = resize data.pitches[note];
            s = resize data.sources[instrument][note];
            mat.sum [acc, mat.entryWiseProduct [w, p, s]]
        done acc [limits.lowest .. limits.highest]);
    estimate = fold addNotes (mat.constMatrix epsilon chunkSize) data.instruments;
    mat.entryWiseDivide chunk estimate);
Chris@14
|
52
|
// Maximisation step (work in progress).
//
// Starts from an `epsilon`-filled matrix the size of `chunk`. For every
// note in data.lowest .. data.highest and every instrument whose range
// contains that note, it adds the entry-wise product of:
//   - the note's template column for that instrument, repeated across the
//     width of `chunk`,
//   - the note's pitch row, repeated down the height of `chunk`,
//   - the instrument/note source row, repeated down the height of `chunk`,
//   - `error`.
// Returns the accumulated matrix.
//
// NOTE(review): the stated intent is, for each pitch, the sum of
// error * original over all templates and instruments, divided by that
// same sum over all pitches. Neither the division nor a per-pitch result
// is implemented yet; only the overall accumulated sum is returned.
performMaximisation data chunk error =
   (// Contribution of a single (note, instrument) pair.
    contribution note instrument =
       (column = mat.getColumn note data.templates[instrument];
        stretchDown = mat.repeatedVertical (mat.height chunk);
        mat.entryWiseProduct [
            mat.repeatedHorizontal (mat.width chunk) (mat.newColumnVector column),
            stretchDown data.pitches[note],
            stretchDown data.sources[instrument][note],
            error
        ]);
    // Instruments whose range covers `note`.
    playing note = select do i: inRange data.ranges i note done data.instruments;
    start = mat.constMatrix epsilon (mat.size chunk);
    fold do total note:
        fold do total instrument:
            mat.sum [total, contribution note instrument]
        done total (playing note)
    done start [data.lowest .. data.highest]);
Chris@13
|
69
|
// Module value: the public entry points of `em`.
{
    initialise = initialise,
    performExpectation = performExpectation,
    performMaximisation = performMaximisation
}
Chris@13
|
75
|
Chris@13
|
76
|