view yeti/em.yeti @ 24:0e8ee830b5ee

Normalise source distributions row-wise (per note)
author Chris Cannam
date Mon, 31 Mar 2014 12:46:24 +0100
parents d75dd38a12a5
children fbc4011c7693
line wrap: on
line source

module em;

mm = load may.mathmisc;
vec = load may.vector;
mat = load may.matrix;

// True when note lies within the inclusive [lowest, highest] range recorded
// for instrument in the ranges hash.
inRange ranges instrument note =
   (range = ranges[instrument];
    range.lowest <= note and note <= range.highest);

// Build the initial model state.
//   ranges    - hash from instrument to a { lowest, highest } note range
//   templates - stored unchanged in the result
//   notes     - number of per-note entries to create (indices 0 .. notes-1)
//   size      - record supplying columns (and rows, which is overridden to 1)
// Returns a record with random pitch rows, 0/1 source rows, the instrument
// list, the inputs, and the overall lowest/highest note across instruments.
// NOTE(review): lowest/highest use head, so this fails if ranges is empty.
initialise ranges templates notes size =
   (instruments = keys ranges;
    noteIndices = [0..notes-1];

    // First element of the sorted list, and first of the reversed sorted list.
    smallest xs = head (sort xs);
    largest xs = head (reverse (sort xs));

    // Source row for one (instrument, note): a single row of 1s when the note
    // is inside the instrument's range, otherwise a single row of 0s.
    sourceRow instrument note =
        mat.constMatrix
           (if inRange ranges instrument note then 1 else 0 fi)
           (size with { rows = 1 });

    {
        // z in the original: one random 1 x size.columns matrix per note.
        pitches =
            array (map do _:
                mat.randomMatrix { rows = 1, columns = size.columns }
            done noteIndices),
        // u in the original: instrument -> array of one source row per note.
        sources =
            mapIntoHash id
                do instrument: array (map (sourceRow instrument) noteIndices) done
                instruments,
        instruments,
        templates,
        ranges,
        lowest = smallest (map do i: ranges[i].lowest done instruments),
        highest = largest (map do i: ranges[i].highest done instruments),
    });

// Tiny positive constant used as the starting value of the accumulated
// matrices in performExpectation and performMaximisation. In
// performExpectation the accumulated estimate is the divisor of
// mat.entryWiseDivide, so this floor keeps entries with no model
// contribution from being zero there.
epsilon = 1e-16;

// select predicate list: the elements of list for which predicate returns
// true, in their original order. This is exactly the standard library's
// filter, so delegate to it instead of re-implementing it via concatMap.
select predicate = filter predicate;

// Expectation step. Reconstructs an estimate of chunk as epsilon plus the sum,
// over every instrument and every note in that instrument's range, of the
// entry-wise product of the note's template column, pitch row and the
// instrument's source row (each tiled to the chunk's size). Returns chunk
// divided entry-wise by that estimate.
performExpectation data chunk =
   (dims = mat.size chunk;
    spread = mat.tiledTo dims;

    // Model contribution of a single (instrument, note) pair at chunk size.
    component instrument note =
        mat.entryWiseProduct
           [spread (mat.newColumnVector (mat.getColumn note data.templates[instrument])),
            spread data.pitches[note],
            spread data.sources[instrument][note]];

    // Inclusive list of the notes an instrument can play.
    notesOf instrument =
        [data.ranges[instrument].lowest .. data.ranges[instrument].highest];

    estimate =
        fold do total instrument:
            fold do partial note:
                mat.sum [partial, component instrument note]
            done total (notesOf instrument)
        done (mat.constMatrix epsilon dims) data.instruments;

    mat.entryWiseDivide chunk estimate);

// Maximisation step (work in progress). For every note in
// [data.lowest .. data.highest] and every instrument whose range contains
// that note, forms the entry-wise product of: the note's template column
// repeated across the chunk width, the note's pitch row and the instrument's
// source row repeated down the chunk height, and error. Returns epsilon plus
// the sum of all those products.
//
// NOTE(review): the intended update, per the original inline comment, is
// "sum of error * original for all template and instruments for this pitch,
// divided by sum of error * original for all template and instruments for
// all pitches". The code does not do that yet: it returns one matrix summed
// over every pitch, with no per-pitch split and no division.
// Presumably error is the matrix returned by performExpectation - confirm.
performMaximisation data chunk error =
   (chunkWidth = mat.width chunk;
    chunkHeight = mat.height chunk;

    // Instruments whose range contains the given note.
    playableBy note =
        select do i: inRange data.ranges i note done data.instruments;

    // Weighted contribution of one (note, instrument) pair.
    term note instrument =
       (column = mat.newColumnVector (mat.getColumn note data.templates[instrument]);
        mat.entryWiseProduct
           [mat.repeatedHorizontal chunkWidth column,
            mat.repeatedVertical chunkHeight data.pitches[note],
            mat.repeatedVertical chunkHeight data.sources[instrument][note],
            error]);

    fold do total note:
        fold do running instrument:
            mat.sum [running, term note instrument]
        done total (playableBy note)
    done (mat.constMatrix epsilon (mat.size chunk)) [data.lowest .. data.highest]);

// Public interface of the em module: the two EM step functions and the
// model-state constructor. Helpers (inRange, select, epsilon) stay private.
{
    performExpectation,
    performMaximisation,
    initialise,
}