Chris@19
|
1
|
Chris@19
|
2 module em_onecolumn;
|
Chris@19
|
3
|
Chris@19
|
4 mm = load may.mathmisc;
|
Chris@19
|
5 vec = load may.vector;
|
Chris@19
|
6 mat = load may.matrix;
|
Chris@19
|
7
|
Chris@20
|
8 plot = load may.plot;
|
Chris@20
|
9
|
Chris@19
|
// True when note lies inside the inclusive range [lowest, highest]
// recorded for the given instrument in ranges (a hash or array of
// records with .lowest and .highest fields).
inRange ranges instrument note =
   (bounds = ranges[instrument];
    bounds.lowest <= note and note <= bounds.highest);
|
Chris@19
|
12
|
Chris@22
|
// Scale v so that its elements sum to 1. When the sum is not strictly
// positive (for example an all-zero vector), v is returned unchanged.
normalise v =
   (total = vec.sum v;
    if not (total > 0) then v
    else vec.divideBy total v
    fi);
|
Chris@22
|
18
|
Chris@19
|
// Build the initial model state for the single-column EM decomposition.
// ranges: hash from instrument name to a record with .lowest/.highest note
//   indices; templates: hash from instrument name to a matrix whose columns
//   are indexed by note (presumably one per note - confirm with callers);
// notes: number of notes (pitches) in the model.
// Instruments are ordered by sorted name; every per-instrument field of the
// result (templates, ranges, source columns) follows that order.
initialise ranges templates notes =
   (names = sort (keys ranges);
    noteIndices = [0..notes-1];

    // 1/0 indicator over notes: 1 where the note is inside the
    // instrument's range.
    indicator name =
        vec.fromList
           (map do n:
                if inRange ranges name n then 1 else 0 fi
            done noteIndices);

    // The instrument's template matrix with every column normalised.
    normalisedTemplate name =
        mat.fromColumns (map normalise (mat.asColumns templates[name]));

    rangeList = map do name: ranges[name] done names;
    lows = map do r: r.lowest done rangeList;
    highs = map do r: r.highest done rangeList;

    {
        // z in the original: random initial pitch distribution, 1 per note
        pitches = normalise (vec.randoms notes),
        // u in the original: 1 per note-instrument
        //!!! should this be normalised across instruments? i.e. row-wise
        sources = mat.fromColumns (map indicator names),
        instruments = array names,
        instCount = length names,
        noteCount = notes,
        templates = array (map normalisedTemplate names),
        ranges = array rangeList,
        // Smallest lowest note and largest highest note over all instruments.
        lowest = head (sort lows),
        highest = head (reverse (sort highs)),
    });
|
Chris@19
|
50
|
Chris@19
|
// Small positive floor added to estimates and weights so they never
// become exactly zero.
epsilon = 1e-16;

// Return the elements of a list for which predicate holds, keeping
// their original order.
select predicate items =
    concatMap do item:
        if not (predicate item) then [] else [item] fi
    done items;
|
Chris@19
|
54
|
Chris@22
|
// Gather the model quantities for one (instrument, note) pair:
//   w - column `note` of the instrument's template matrix,
//   p - the pitch weight of the note,
//   s - the source weight at row `note`, column `instNo` of data.sources.
distributionsFor data instNo note =
   (template = data.templates[instNo];
    w = mat.getColumn note template;
    p = vec.at data.pitches note;
    s = mat.at data.sources note instNo;
    { w, p, s });
|
Chris@19
|
61
|
Chris@19
|
// E-step for one spectral column. The model estimate is an epsilon floor
// in every bin plus, for each instrument and each note inside that
// instrument's range, the template column w scaled by p * s. Returns the
// normalised column divided elementwise by that estimate.
performExpectation data column =
   (observed = normalise column;

    // Add the term of one (instrument, note) pair to the running estimate.
    addTerm inst estimate note =
       ({ w, p, s } = distributionsFor data inst note;
        vec.add [estimate, vec.scaled (p * s) w]);

    // Add the terms of every in-range note of one instrument.
    addInstrument estimate inst =
       (bounds = data.ranges[inst];
        fold (addTerm inst) estimate [bounds.lowest .. bounds.highest]);

    baseline = vec.consts epsilon (vec.length observed);
    model = fold addInstrument baseline [0..data.instCount-1];
    vec.divide observed model);
|
Chris@19
|
73
|
Chris@19
|
// M-step for one spectral column.
// column: the observed column; only its length (the bin count) is used here.
// q: the per-bin ratio returned by performExpectation.
// Returns data with re-estimated pitches (normalised to sum to 1) and
// sources (each note's weights normalised to sum to 1 across instruments).
performMaximisation data column q =
   (binCount = vec.length column;
    instIndices = [0..data.instCount-1];
    noteIndices = [0..data.noteCount-1];

    // Add to `start` the weight of one (instrument, note) pair summed over
    // all bins: sum_bin s * p * w[bin] * q[bin].
    contribution start inst note =
       ({ w, p, s } = distributionsFor data inst note;
        fold do acc bin:
            acc + s * p * (vec.at w bin) * (vec.at q bin);
        done start [0..binCount - 1]);

    // Pitch update. Only instruments whose range covers the note
    // contribute. This matches the pairs modelled in performExpectation and
    // those used by the source update below. Notes that no instrument
    // covers keep just the epsilon floor.
    pitches = normalise (vec.fromList
       (map do note:
            fold do acc inst:
                if inRange data.ranges inst note then
                    contribution acc inst note
                else
                    acc
                fi
            done epsilon instIndices
        done noteIndices));

    // Source update before normalisation; out-of-range pairs get epsilon.
    rawSources = mat.fromColumns
       (map do inst: vec.fromList
           (map do note:
                if inRange data.ranges inst note then
                    contribution epsilon inst note
                else
                    epsilon
                fi
            done noteIndices)
        done instIndices);

    // Per-note totals across instruments, used to normalise each row.
    sourceDenoms = fold do acc inst:
        vec.add [acc, (mat.getColumn inst rawSources)]
    done (vec.zeros data.noteCount) instIndices;

    sources = mat.fromColumns
       (map do inst:
            vec.divide (mat.getColumn inst rawSources) sourceDenoms
        done instIndices);

    data with {
        pitches,
        sources,
    });
|
Chris@19
|
116
|
Chris@19
|
// Module value: the model initialiser plus the two EM steps.
{ performExpectation, performMaximisation, initialise }
|
Chris@19
|
122
|
Chris@19
|
123
|