Chris@19
|
1
|
Chris@19
|
2 module em_onecolumn;
|
Chris@19
|
3
|
Chris@19
|
4 mm = load may.mathmisc;
|
Chris@19
|
5 vec = load may.vector;
|
Chris@19
|
6 mat = load may.matrix;
|
Chris@19
|
7
|
Chris@20
|
8 plot = load may.plot;
|
Chris@20
|
9
|
// True when `note` lies in the inclusive range [lowest, highest] stored for
// `instrument` in `ranges`. Callers in this module pass either the hash of
// ranges keyed by instrument name or the array of ranges keyed by index.
inRange ranges instrument note =
   (r = ranges[instrument];
    r.lowest <= note and note <= r.highest);

// Scales vector `v` so that its elements sum to 1. When the sum is not
// strictly positive (zero, negative or NaN) `v` is returned unchanged.
normalise v =
   (divideIfPositive total =
        if total > 0 then vec.divideBy total v else v fi;
    divideIfPositive (vec.sum v));

// Normalises the source matrix `s` (rows = notes, columns = instruments) so
// that, in every row, the entries across all columns sum to 1.
// A row whose entries sum to zero (for example a note that no instrument can
// play, as produced by initialise's range mask) is left as all zeros instead
// of being divided by zero. That division would fill the row with NaN (0/0)
// and poison every later sum that reads it.
normaliseSources s =
   (cols = [0..(mat.width s)-1];
    // Per-row (per-note) total over all instruments.
    totals = fold do acc inst: vec.add [acc, mat.getColumn inst s] done
        (vec.zeros (mat.height s)) cols;
    // Substitute 1 for zero totals so those rows divide to 0, not NaN.
    denoms = vec.fromList
       (map do i:
            t = vec.at totals i;
            if t == 0 then 1 else t fi
        done [0..(vec.length totals)-1]);
    mat.fromColumns
       (map do inst: vec.divide (mat.getColumn inst s) denoms done cols));

// Builds the initial model state.
//   ranges    - hash from instrument name to a record with `lowest` and
//               `highest` note indices
//   templates - hash from instrument name to a matrix whose columns are
//               indexed by note
//   notes     - number of notes modelled
// Instruments are ordered by sorted name. Every per-instrument field of the
// result (instruments, templates, ranges, and the columns of sources) uses
// that order, so afterwards an instrument is identified by its index.
initialise ranges templates notes =
   (names = sort (keys ranges);
    noteIndices = [0..notes-1];

    orderedRanges = map do name: ranges[name] done names;

    // Each template column scaled to sum to 1 (see normalise).
    orderedTemplates =
        map do name:
            t = templates[name];
            mat.fromColumns (map normalise (mat.asColumns t))
        done names;

    // 1 for every note inside the instrument's range, 0 elsewhere.
    rangeMask name =
        vec.fromList
           (map do note:
                if inRange ranges name note then 1 else 0 fi
            done noteIndices);

    lows = sort (map do r: r.lowest done orderedRanges);
    highs = sort (map do r: r.highest done orderedRanges);

    {
        // z in the original. 1 per note
        pitches = normalise (vec.randoms notes),
        // u in the original. 1 per note-instrument
        sources = normaliseSources (mat.fromColumns (map rangeMask names)),
        instruments = array names,
        instCount = length names,
        noteCount = notes,
        templates = array orderedTemplates,
        ranges = array orderedRanges,
        lowest = head lows,
        highest = head (reverse highs),
    });

// Small positive floor used as the starting value of accumulated estimates,
// so that performExpectation never divides by an exact zero.
epsilon = 1e-16;

// Returns the elements of a list that satisfy `predicate`, in their original
// order. This is the standard `filter`; the previous hand-rolled concatMap
// version did the same thing less directly.
// NOTE(review): not referenced anywhere in this module and not exported in
// the module's result record; candidate for removal.
select predicate = filter predicate;

// Collects the model quantities for one (instrument index, note) pair:
//   w - column `note` of the instrument's template matrix
//   p - entry `note` of the pitch vector
//   s - entry (note, instrument) of the source matrix
distributionsFor data instNo note =
   (template = data.templates[instNo];
    {
        w = mat.getColumn note template,
        p = vec.at data.pitches note,
        s = mat.at data.sources note instNo,
    });

// E-step. Normalises `column`, then builds the model's estimate of it:
// starting from epsilon in every bin, it adds p * s * w for every instrument
// and every note in that instrument's range. Returns the estimate together
// with q, the normalised column divided elementwise by the estimate.
performExpectation data column =
   (column = normalise column;
    // Adds the contributions of every note in instrument `inst`'s range
    // onto `acc`.
    addInstrument acc inst =
       (r = data.ranges[inst];
        fold do total note:
            d = distributionsFor data inst note;
            vec.add [total, vec.scaled (d.p * d.s) d.w]
        done acc [r.lowest .. r.highest]);
    baseline = vec.consts epsilon (vec.length column);
    estimate = fold addInstrument baseline [0..data.instCount-1];
    { estimate, q = vec.divide column estimate });

// M-step. Re-estimates the pitch vector and the source matrix of `data` from
// the per-bin weight vector `q`, which is presumably the `q` returned by
// performExpectation for the same column (TODO confirm at call sites).
// Returns `data` with only its `pitches` and `sources` fields replaced.
// `column` is used only for its length, i.e. the number of bins summed over.
// Normalising it first, as before, did not change that length, so the step
// is skipped.
performMaximisation data column q =
   (bins = [0..vec.length column - 1];
    instIndices = [0..data.instCount-1];
    noteIndices = [0..data.noteCount-1];

    // Adds sum over bins of s * p * w[bin] * q[bin] onto `acc`, for the
    // quantities `d` of one (instrument, note) pair. The accumulator is
    // threaded through so the summation order matches the original code.
    addBinSum acc d =
        fold do t bin: t + d.s * d.p * (vec.at d.w bin) * (vec.at q bin) done
            acc bins;

    // Pitch update. Each note in [data.lowest, data.highest] accumulates,
    // starting from epsilon, over all instruments; other notes get epsilon.
    // Every entry is therefore at least epsilon, so the sum is positive
    // and normalise always rescales.
    pitches = normalise (vec.fromList
       (map do note:
            if note >= data.lowest and note <= data.highest then
                fold do acc inst:
                    addBinSum acc (distributionsFor data inst note)
                done epsilon instIndices
            else
                epsilon
            fi
        done noteIndices));

    // Source update: epsilon plus the bin sum for notes inside the
    // instrument's range, epsilon elsewhere; then normalised per note.
    sources = normaliseSources
       (mat.fromColumns
           (map do inst:
                vec.fromList
                   (map do note:
                        if inRange data.ranges inst note then
                            addBinSum epsilon (distributionsFor data inst note)
                        else
                            epsilon
                        fi
                    done noteIndices)
            done instIndices));

    data with {
        pitches,
        sources,
    });

// Module value: the functions exported by em_onecolumn.
{
    initialise = initialise,
    performExpectation = performExpectation,
    performMaximisation = performMaximisation
}
