Mercurial > hg > silvet
diff yeti/em.yeti @ 14:a91de434feb8
More (sometimes baffled) annotations and a bit of work on the EM
author | Chris Cannam |
---|---|
date | Mon, 24 Mar 2014 16:31:20 +0000 |
parents | e15bc63cb146 |
children | 2b7257e4fc8a |
line wrap: on
line diff
--- a/yeti/em.yeti Fri Mar 21 18:12:38 2014 +0000 +++ b/yeti/em.yeti Mon Mar 24 16:31:20 2014 +0000 @@ -5,27 +5,68 @@ vec = load may.vector; mat = load may.matrix; -initialiseEM ranges notes size = +inRange ranges instrument note = + note >= ranges[instrument].lowest and note <= ranges[instrument].highest; + +initialise ranges templates notes size = + (instruments = keys ranges; { - pitches = // z in the original + pitches = // z in the original. 1xN per note array (map do note: - map \(mm.random ()) [0..size.columns-1] + mat.randomMatrix { rows = 1, columns = size.columns } done [0..notes-1]), - sources = - mapIntoHash id // u in the original + sources = // u in the original. 1xN per note-instrument + mapIntoHash id do instrument: array (map do note: - if note >= ranges[instrument].lowestNote and - note <= ranges[instrument].highestNote - then vec.ones size.columns - else vec.zeros size.columns - fi + mat.constMatrix + (if inRange ranges instrument note then 1 else 0 fi) + (size with { rows = 1 }) done [0..notes-1]) - done (keys ranges); - }; + done instruments, + instruments, + templates, + ranges, + lowest = head (sort (map do i: ranges[i].lowest done instruments)), + highest = head (reverse (sort (map do i: ranges[i].highest done instruments))), + }); + +epsilon = 1e-16; + +select predicate = concatMap do v: if predicate v then [v] else [] fi done; + +performExpectation data chunk = + (estimate = + fold do acc instrument: + fold do acc note: + template = mat.getColumn note data.templates[instrument]; + w = mat.repeatedHorizontal (mat.width chunk) (mat.newColumnVector template); + p = mat.repeatedVertical (mat.height chunk) data.pitches[note]; + s = mat.repeatedVertical (mat.height chunk) data.sources[instrument][note]; + mat.sum [acc, mat.entryWiseProduct [w, p, s]]; + done acc [data.ranges[instrument].lowest .. + data.ranges[instrument].highest] + done (mat.constMatrix epsilon (mat.size chunk)) data.instruments; + mat.entryWiseDivide chunk estimate); + +performMaximisation data chunk error = + (fold do acc note: + fold do acc instrument: + template = mat.getColumn note data.templates[instrument]; + w = mat.repeatedHorizontal (mat.width chunk) (mat.newColumnVector template); + p = mat.repeatedVertical (mat.height chunk) data.pitches[note]; + s = mat.repeatedVertical (mat.height chunk) data.sources[instrument][note]; + + + mat.sum [acc, mat.entryWiseProduct [w, s, error]] + + done acc (select do i: inRange data.ranges i note done data.instruments) + done (mat.constMatrix epsilon (mat.size chunk)) [data.lowest .. data.highest]); { - initialiseEM + initialise, + performExpectation, + performMaximisation, }