annotate yeti/em.yeti @ 372:af71cbdab621 tip

Update bqvec code
author Chris Cannam
date Tue, 19 Nov 2019 10:13:32 +0000
parents cd9fd74931bb
children
rev   line source
Chris@13 1
Chris@13 2 module em;
Chris@13 3
Chris@13 4 mm = load may.mathmisc;
Chris@13 5 vec = load may.vector;
Chris@13 6 mat = load may.matrix;
Chris@13 7
// True when note lies within the inclusive [lowest, highest] range of the
// instrument with index inst in the ranges array.
inRange ranges inst note =
   (r = ranges[inst];
    r.lowest <= note and note <= r.highest);
Chris@26 10
// Divide vector v by its sum so that its entries add up to 1. If the sum is
// not strictly positive, v is returned untouched (avoids dividing by zero).
normaliseColumn v =
   (total = vec.sum v;
    if not (total > 0) then v else vec.divideBy total v fi);
Chris@26 16
// Apply normaliseColumn to every column of matrix m.
normaliseChunk m =
    mat.mapColumns normaliseColumn m;
Chris@26 19
// Normalise a list/array of equally sized matrices so that, at every
// (row, column) position, the entries across all matrices sum to 1.
// Returns an array of the normalised matrices, in the input order.
// Fails if sourceMatrices is empty (its first element supplies the size).
normaliseSources sourceMatrices =
   (// Seed the denominator with a tiny positive constant instead of zeros, so
    // that a position where every source is 0 yields 0 rather than 0/0 (NaN).
    // initialise creates such positions for notes outside every instrument's
    // range. The value equals the module's epsilon, which is bound later in
    // the module and so cannot be referenced here.
    guard = 1e-16;
    denom = fold do acc source: mat.sum [acc, source] done
        (mat.constMatrix guard (mat.size sourceMatrices[0])) sourceMatrices;
    array
       (map do source: mat.entryWiseDivide source denom done sourceMatrices));
Chris@14 25
// Build the initial EM state.
//   ranges    - hash from instrument name to a struct with .lowest/.highest
//   templates - hash from instrument name to a template matrix
//   notes     - number of notes (rows of the pitch/source matrices)
//   size      - struct whose .columns gives the number of columns
// The returned struct holds randomly initialised, column-normalised pitches,
// range-masked sources normalised across instruments, the instrument names
// in sorted order, ranges/templates re-indexed by that order, and the overall
// lowest/highest note over all instruments.
initialise ranges templates notes size =
   (instrumentNames = sort (keys ranges);
    // Shadow the hash argument: from here on ranges is an array indexed by
    // instrument number, following the sorted name order.
    ranges = array (map (at ranges) instrumentNames);
    nInstruments = length instrumentNames;
    shape = { rows = notes, columns = size.columns };

    // Column vector with one entry per note: 1 where the instrument can
    // play the note, 0 elsewhere.
    rangeMask inst =
        mat.newColumnVector
           (vec.fromList
               (map do note: if inRange ranges inst note then 1 else 0 fi done
                    [0..notes-1]));

    {
        // z in the original: one 1xN row per note.
        pitches = normaliseChunk (mat.randomMatrix shape),
        // u in the original: one notes x N matrix per instrument, each row
        // being that instrument's share for the note.
        sources = normaliseSources
           (array
               (map do inst: mat.tiledTo shape (rangeMask inst) done
                    [0..nInstruments-1])),
        instrumentNames,
        nInstruments,
        nNotes = notes,
        ranges,
        templates = array
           (map do iname: normaliseChunk templates[iname] done instrumentNames),
        lowest = fold do lo r: if r.lowest < lo then r.lowest else lo fi done
            ranges[0].lowest ranges,
        highest = fold do hi r: if r.highest > hi then r.highest else hi fi done
            ranges[0].highest ranges,
    });
Chris@14 58
// Small positive floor used to seed sums so estimates never reach zero.
epsilon = 1.0e-16;
Chris@14 60
// Keep only the elements of xs for which predicate returns true,
// preserving their order.
select predicate xs =
    filter predicate xs;
Chris@14 62
// Extract the per-(instrument, note) factors from the EM state:
//   w - column `note` of the instrument's template, as a column vector
//   p - row `note` of the pitch matrix, as a row vector
//   s - row `note` of the instrument's source matrix, as a row vector
distributionsFor data inst note =
   (columnOf m = mat.newColumnVector (mat.getColumn note m);
    rowOf m = mat.newRowVector (mat.getRow note m);
    w = columnOf data.templates[inst];
    p = rowOf data.pitches;
    s = rowOf data.sources[inst];
    { w, p, s });
Chris@26 69
// E-step. Builds the model estimate of chunk by summing, over every
// instrument and every note inside that instrument's range, the outer
// product of the template column w with the element-wise product p .* s.
// The sum starts from an epsilon-filled matrix of chunk's size.
// Returns { estimate, q } where q = chunk ./ estimate.
performExpectation data chunk =
   (contribution inst note =
       ({ w, p, s } = distributionsFor data inst note;
        mat.product w (mat.entryWiseProduct [p, s]));
    addInstrument total inst =
       (r = data.ranges[inst];
        fold do sofar note: mat.sum [sofar, contribution inst note] done
            total [r.lowest .. r.highest]);
    start = mat.constMatrix epsilon (mat.size chunk);
    estimate = fold addInstrument start [0 .. data.nInstruments - 1];
    q = mat.entryWiseDivide chunk estimate;
    { estimate, q });
Chris@14 80
// M-step. Given the E-step ratio q = chunk ./ estimate (as returned by
// performExpectation), updates the pitch and source distributions:
//   for each note and instrument,
//     prod = epsilon + sum over bins b of w[b] * (p .* s .* q[b, :])
//   pitches row for a note = sum of prod over instruments (epsilon-only
//     rows for notes outside data.lowest..data.highest), then column-normalised
//   sources[inst] row for a note = prod if the note is in inst's range,
//     epsilon otherwise, then normalised across instruments.
// chunk is used only for its dimensions; this assumes q has the same size
// as chunk, as produced by performExpectation.
// Returns data with updated pitches and sources.
performMaximisation data chunk q =
   (bins = mat.height chunk;
    columns = mat.width chunk;
    // 1 x columns matrix of epsilon: fold seed and filler row.
    e = mat.constMatrix epsilon { rows = 1, columns };

    // {inst, note} -> unnormalised product row, reused for the sources update.
    noteInstrumentProducts = [:];

    // The pitch pass must run before the sources pass below, since it
    // fills noteInstrumentProducts.
    pitches = mat.concatVertical
       (map do note:
            if note < data.lowest or note > data.highest then e else
                fold do acc inst:
                    { w, p, s } = distributionsFor data inst note;
                    // p .* s does not depend on the bin, so compute it once.
                    ps = mat.entryWiseProduct [p, s];
                    prod =
                        fold do sofar bin:
                            mat.sum
                               [sofar,
                                mat.scaled (mat.at w bin 0)
                                   (mat.entryWiseProduct
                                       [ps, mat.newRowVector (mat.getRow bin q)])]
                        done e [0..bins - 1];
                    noteInstrumentProducts[{inst,note}] := prod;
                    mat.sum [acc, prod];
                done e [0..data.nInstruments-1]
            fi;
        done [0..data.nNotes - 1]);
    pitches = normaliseChunk pitches;

    sources = array
       (map do inst:
           (mat.concatVertical
               (map do note:
                    if not inRange data.ranges inst note then e else
                        noteInstrumentProducts[{inst,note}]
                    fi
                done [0..data.nNotes - 1]))
        done [0..data.nInstruments-1]);
    sources = normaliseSources sources;

    data with {
        pitches,
        sources,
    });
Chris@26 123
Chris@26 124
Chris@13 125
// Public interface of the em module.
{
    initialise = initialise,
    performExpectation = performExpectation,
    performMaximisation = performMaximisation
}
Chris@13 131
Chris@13 132