module em;

mm = load may.mathmisc;
vec = load may.vector;
mat = load may.matrix;

// True if note lies within the inclusive lowest..highest range recorded for
// instrument index inst in the ranges array.
inRange ranges inst note =
    note >= ranges[inst].lowest and note <= ranges[inst].highest;

// Scale vector v so that its elements sum to 1. A vector whose sum is not
// positive is returned unchanged.
normaliseColumn v =
   (s = vec.sum v;
    if s > 0 then vec.divideBy s v
    else v
    fi);

// Apply normaliseColumn to every column of a matrix.
normaliseChunk =
    mat.mapColumns normaliseColumn;

// Divide each matrix in sourceMatrices entry-wise by the entry-wise sum of
// all of them, so that corresponding entries sum to 1 across the list.
// All matrices are expected to have the size of sourceMatrices[0].
// NOTE(review): where every source has a zero at the same position the
// divisor is zero too (e.g. in initialise, a note covered by no instrument's
// range) -- confirm how mat.entryWiseDivide handles 0/0 (likely NaN).
normaliseSources sourceMatrices =
   (denom = fold do acc source: mat.sum [acc, source] done
       (mat.zeroMatrix (mat.size sourceMatrices[0])) sourceMatrices;
    array
       (map do source: mat.entryWiseDivide source denom done sourceMatrices));

// Build the initial model state.
//   ranges    - hash from instrument name to a record with lowest and
//               highest note fields
//   templates - hash from instrument name to a matrix; column n of it is
//               read for note n (see distributionsFor)
//   notes     - number of notes, i.e. rows of the pitch and source matrices
//   size      - record whose columns field gives the number of columns
// Instruments are indexed by their position in the sorted list of names;
// the returned ranges and templates are arrays in that order.
initialise ranges templates notes size =
   (instrumentNames = sort (keys ranges);
    ranges = array (map (at ranges) instrumentNames);
    {
        pitches = // z in the original. 1xN per note
            // Random values, normalised so each column sums to 1.
            normaliseChunk
               (mat.randomMatrix { rows = notes, columns = size.columns }),
        sources = // u in the original. a tensor, 1xN per note-instrument
            // One notes x columns matrix per instrument: 1 where the note is
            // in that instrument's range, 0 elsewhere, then normalised so
            // each entry sums to 1 across instruments.
            normaliseSources
               (array
                   (map do inst:
                        mat.tiledTo { rows = notes, columns = size.columns }
                           (mat.newColumnVector
                               (vec.fromList
                                   (map do note:
                                        if inRange ranges inst note
                                        then 1
                                        else 0
                                        fi
                                    done [0..notes-1])))
                    done [0..length instrumentNames - 1])),
        instrumentNames,
        nInstruments = length instrumentNames,
        nNotes = notes,
        ranges,
        templates = array
           (map do iname:
                normaliseChunk templates[iname];
            done instrumentNames),
        // Lowest and highest note found in any instrument's range.
        lowest = head (sort (map do r: r.lowest done ranges)),
        highest = head (reverse (sort (map do r: r.highest done ranges))),
    });

// Small positive floor used to seed the accumulators below, so that the
// estimate (a divisor in performExpectation) and the maximisation products
// are never exactly zero.
epsilon = 1e-16;

// Return the elements of a list that satisfy predicate.
// (Not used elsewhere in this module.)
select predicate = concatMap do v: if predicate v then [v] else [] fi done;

// Fetch the model factors for one instrument index and note:
//   w - column note of that instrument's template, as a column vector
//   p - row note of the pitch matrix, as a row vector
//   s - row note of that instrument's source matrix, as a row vector
distributionsFor data inst note =
    {
        w = mat.newColumnVector (mat.getColumn note data.templates[inst]),
        p = mat.newRowVector (mat.getRow note data.pitches),
        s = mat.newRowVector (mat.getRow note data.sources[inst]),
    };

// E-step. Reconstruct an estimate of chunk from the current model, summing
// w * (p .* s) over every instrument and every note in that instrument's
// range, on top of an epsilon floor. Returns the estimate together with
// q = chunk ./ estimate (entry-wise), which performMaximisation consumes.
performExpectation data chunk =
   (estimate =
        fold do acc inst:
            fold do acc note:
                { w, p, s } = distributionsFor data inst note;
                mat.sum [acc, mat.product w (mat.entryWiseProduct [p, s]) ];
            done acc [data.ranges[inst].lowest ..
                      data.ranges[inst].highest]
        done (mat.constMatrix epsilon (mat.size chunk)) [0..data.nInstruments-1];
    { estimate, q = mat.entryWiseDivide chunk estimate});

// M-step. Update the pitch and source distributions from q, the
// observed/estimate ratio returned by performExpectation. Each bin's
// contribution to a note/instrument product is weighted by the matching row
// of q; weighting by the observed chunk alone would ignore the current
// estimate entirely. chunk is used here only for its dimensions, which q
// shares. Returns data with updated pitches and sources.
performMaximisation data chunk q =
   (columns = mat.width chunk;
    bins = mat.height chunk;
    e = mat.constMatrix epsilon { rows = 1, columns };

    // Per {inst, note} product, filled in while pitches is built below and
    // re-used for the source update. pitches must therefore be fully
    // evaluated before sources is constructed.
    noteInstrumentProducts = [:];

    pitches = mat.concatVertical
       (map do note:
            if note < data.lowest or note > data.highest then e else
                fold do acc inst:
                    { w, p, s } = distributionsFor data inst note;
                    prod =
                        fold do acc bin:
                            mat.sum
                               [acc,
                                mat.scaled (mat.at w bin 0)
                                   (mat.entryWiseProduct
                                       [p, s, mat.newRowVector (mat.getRow bin q)])]
                        done e [0..bins - 1];
                    noteInstrumentProducts[{inst,note}] := prod;
                    mat.sum [acc, prod];
                done e [0..data.nInstruments-1]
            fi;
        done [0..data.nNotes - 1]);
    pitches = normaliseChunk pitches;

    // Every in-range {inst, note} pair has a stored product: a note inside
    // an instrument's range is also inside lowest..highest.
    sources = array
       (map do inst:
           (mat.concatVertical
               (map do note:
                    if not inRange data.ranges inst note then e else
                        noteInstrumentProducts[{inst,note}]
                    fi
                done [0..data.nNotes - 1]))
        done [0..data.nInstruments-1]);
    sources = normaliseSources sources;

    data with {
        pitches,
        sources,
    });

{
    initialise,
    performExpectation,
    performMaximisation,
}