Mercurial > hg > silvet
diff yeti/em.yeti @ 26:fbc4011c7693
Now to try to get the matrix version working!
author | Chris Cannam |
---|---|
date | Mon, 31 Mar 2014 17:03:51 +0100 |
parents | d75dd38a12a5 |
children | cd9fd74931bb |
line wrap: on
line diff
--- a/yeti/em.yeti Mon Mar 31 15:41:15 2014 +0100 +++ b/yeti/em.yeti Mon Mar 31 17:03:51 2014 +0100 @@ -5,67 +5,125 @@ vec = load may.vector; mat = load may.matrix; -inRange ranges instrument note = - note >= ranges[instrument].lowest and note <= ranges[instrument].highest; +inRange ranges inst note = + note >= ranges[inst].lowest and note <= ranges[inst].highest; + +normaliseColumn v = + (s = vec.sum v; + if s > 0 then vec.divideBy s v + else v + fi); + +normaliseChunk = + mat.mapColumns normaliseColumn; + +normaliseSources sourceMatrices = + (denom = fold do acc source: mat.sum [acc, source] done + (mat.zeroMatrix (mat.size sourceMatrices[0])) sourceMatrices; + array + (map do source: mat.entryWiseDivide source denom done sourceMatrices)); initialise ranges templates notes size = - (instruments = keys ranges; + (instrumentNames = sort (keys ranges); + ranges = array (map (at ranges) instrumentNames); { pitches = // z in the original. 1xN per note - array (map do note: - mat.randomMatrix { rows = 1, columns = size.columns } - done [0..notes-1]), - sources = // u in the original. 1xN per note-instrument - mapIntoHash id - do instrument: - array (map do note: - mat.constMatrix - (if inRange ranges instrument note then 1 else 0 fi) - (size with { rows = 1 }) - done [0..notes-1]) - done instruments, - instruments, - templates, + normaliseChunk + (mat.randomMatrix { rows = notes, columns = size.columns }), + sources = // u in the original. a tensor, 1xN per note-instrument + normaliseSources + (array + (map do inst: + mat.tiledTo { rows = notes, columns = size.columns } + (mat.newColumnVector + (vec.fromList + (map do note: + if inRange ranges inst note + then 1 + else 0 + fi + done [0..notes-1]))) + done [0..length instrumentNames - 1])), + instrumentNames, + nInstruments = length instrumentNames, + nNotes = notes, ranges, - lowest = head (sort (map do i: ranges[i].lowest done instruments)), - highest = head (reverse (sort (map do i: ranges[i].highest done instruments))), + templates = array + (map do iname: + normaliseChunk templates[iname]; + done instrumentNames), + lowest = head (sort (map do r: r.lowest done ranges)), + highest = head (reverse (sort (map do r: r.highest done ranges))), }); epsilon = 1e-16; select predicate = concatMap do v: if predicate v then [v] else [] fi done; +distributionsFor data inst note = + { + w = mat.newColumnVector (mat.getColumn note data.templates[inst]), + p = mat.newRowVector (mat.getRow note data.pitches), + s = mat.newRowVector (mat.getRow note data.sources[inst]), + }; + performExpectation data chunk = (estimate = - fold do acc instrument: + fold do acc inst: fold do acc note: - template = mat.getColumn note data.templates[instrument]; - resize = mat.tiledTo (mat.size chunk); - w = resize (mat.newColumnVector template); - p = resize data.pitches[note]; - s = resize data.sources[instrument][note]; - mat.sum [acc, mat.entryWiseProduct [w, p, s]]; - done acc [data.ranges[instrument].lowest .. - data.ranges[instrument].highest] - done (mat.constMatrix epsilon (mat.size chunk)) data.instruments; - mat.entryWiseDivide chunk estimate); + { w, p, s } = distributionsFor data inst note; + mat.sum [acc, mat.product w (mat.entryWiseProduct [p, s]) ]; + done acc [data.ranges[inst].lowest .. + data.ranges[inst].highest] + done (mat.constMatrix epsilon (mat.size chunk)) [0..data.nInstruments-1]; + { estimate, q = mat.entryWiseDivide chunk estimate}); -performMaximisation data chunk error = - (pitches = - fold do acc note: - fold do acc instrument: - // want sum of error * original for all template and instruments - // for this pitch, divided by sum of error * original for all - // template and instruments for all pitches +performMaximisation data chunk q = + (chunk = normaliseChunk chunk; + columns = mat.width chunk; + e = mat.constMatrix epsilon { rows = 1, columns }; - template = mat.getColumn note data.templates[instrument]; - w = mat.repeatedHorizontal (mat.width chunk) (mat.newColumnVector template); - p = mat.repeatedVertical (mat.height chunk) data.pitches[note]; - s = mat.repeatedVertical (mat.height chunk) data.sources[instrument][note]; - mat.sum [acc, mat.entryWiseProduct [w, p, s, error]] - done acc (select do i: inRange data.ranges i note done data.instruments) - done (mat.constMatrix epsilon (mat.size chunk)) [data.lowest .. data.highest]; - pitches); + pitches = mat.concatVertical + (map do note: + if note < data.lowest or note > data.highest then e else + fold do acc inst: + { w, p, s } = distributionsFor data inst note; + fold do acc bin: + mat.sum + [acc, + mat.scaled (mat.at w bin 0) + (mat.entryWiseProduct + [p, s, mat.newRowVector (mat.getRow bin chunk)])] + done acc [0..mat.height chunk - 1] + done e [0..data.nInstruments-1] + fi; + done [0..data.nNotes - 1]); + pitches = normaliseChunk pitches; + + sources = array + (map do inst: + (mat.concatVertical + (map do note: + if not inRange data.ranges inst note then e else + { w, p, s } = distributionsFor data inst note; + fold do acc bin: + mat.sum + [acc, + mat.scaled (mat.at w bin 0) + (mat.entryWiseProduct + [p, s, mat.newRowVector (mat.getRow bin chunk)])] + done e [0..mat.height chunk - 1] + fi + done [0..data.nNotes - 1])) + done [0..data.nInstruments-1]); + sources = normaliseSources sources; + + data with { + pitches, + sources, + }); + + { initialise,