view yeti/em_onecolumn.yeti @ 20:982aa1197a7e

Getting there, slowly, sort of, with EM
author Chris Cannam
date Thu, 27 Mar 2014 11:52:07 +0000
parents f1f8c84339d0
children 8e61ec97b34e
line wrap: on
line source

module em_onecolumn;

mm = load may.mathmisc;
vec = load may.vector;
mat = load may.matrix;

plot = load may.plot;

// True when `note` lies within the inclusive [lowest, highest] span that
// `ranges` records for `instrument`.
inRange ranges instrument note =
   (limits = ranges[instrument];
    limits.lowest <= note and note <= limits.highest);

// Build the initial model state.
//
//   ranges    - hash from instrument name to a { lowest, highest } record
//   templates - stored unchanged in the result
//   notes     - number of notes; pitches and each sources vector get this
//               many elements
//
// The result carries pitches (from vec.randoms), sources (one 0/1 vector
// per instrument, 1 where the instrument's range includes the note), the
// instrument list, the templates and ranges as given, and the overall
// lowest and highest note across all instruments.
initialise ranges templates notes =
   (instruments = keys ranges;
    allNotes = [0..notes-1];

    // 1 for every note the instrument's range includes, 0 elsewhere
    maskFor instrument =
        vec.fromList
           (map do n:
                if inRange ranges instrument n then 1 else 0 fi
            done allNotes);

    lowEnds = sort (map do name: ranges[name].lowest done instruments);
    highEnds = sort (map do name: ranges[name].highest done instruments);

    {
        // z in the original. 1 per note
        pitches = vec.randoms notes,
        // u in the original. 1 per note-instrument
        sources = mapIntoHash id maskFor instruments,
        instruments,
        templates,
        ranges,
        lowest = head lowEnds,
        highest = head (reverse highEnds),
    });

// Tiny positive constant. It seeds the accumulators in the expectation and
// maximisation folds and stands in for the source weight of notes outside
// an instrument's range -- presumably so that later divisions never see an
// exact zero.
epsilon = 1e-16;

// Keep, in order, only those elements of `items` that satisfy `predicate`.
select predicate items =
    concatMap do item:
        if predicate item then [item] else [] fi
    done items;

// Gather the three quantities the EM steps need for one instrument/note
// pair:
//   w - column `note` of the instrument's template matrix
//   p - element `note` of data.pitches
//   s - element `note` of the instrument's vector in data.sources
distributionsFor data instrument note =
   (template = data.templates[instrument];
    instrumentSources = data.sources[instrument];
    {
        w = mat.getColumn note template,
        p = vec.at data.pitches note,
        s = vec.at instrumentSources note,
    });

// Expectation step for a single column.
//
// Builds an estimate by starting from a vector of epsilon values (as long
// as `column`) and adding, for every instrument and every note in that
// instrument's range, the template column w scaled by p * s.  Returns
// vec.divide column estimate.
performExpectation data column =
   (baseline = vec.consts epsilon (vec.length column);

    // Add one instrument/note contribution to the running estimate
    addNote instrument total note =
       (d = distributionsFor data instrument note;
        vec.add [total, vec.scaled (d.p * d.s) d.w]);

    // Add every in-range note of one instrument to the running estimate
    addInstrument total instrument =
       (limits = data.ranges[instrument];
        fold (addNote instrument) total [limits.lowest .. limits.highest]);

    estimate = fold addInstrument baseline data.instruments;
    vec.divide column estimate);

// Maximisation step for a single column.
//
//   data   - model state as produced by initialise (or by a previous call)
//   column - the observed column; only its length is used here
//   q      - vector with one element per bin, read via vec.at (presumably
//            the ratio returned by performExpectation -- confirm at caller)
//
// Returns `data` with `pitches` and `sources` replaced by their updated
// values.
//
// The updated vectors keep the same length as the incoming data.pitches
// and stay indexed by absolute note number.  distributionsFor reads them
// with `vec.at ... note` using absolute note numbers, so building them
// over [data.lowest .. data.highest] only (as this function previously
// did) shifted every element by data.lowest and shortened the vectors,
// making the next expectation/maximisation pass read the wrong entries or
// run off the end.  Notes outside the global range are padded: 0 for
// pitches (leaving the normalising sum unchanged) and epsilon for source
// numerators (the same value already used for notes outside an
// instrument's own range).
performMaximisation data column q =
   (noteCount = vec.length data.pitches;
    allNotes = [0 .. noteCount - 1];
    bins = [0 .. vec.length column - 1];

    pitches = vec.fromList
       (map do note:
            if note < data.lowest or note > data.highest then
                // no instrument covers this note; templates are not consulted
                0
            else
                fold do acc instrument:
                    d = distributionsFor data instrument note;
                    fold do acc bin:
                        acc + d.s * (vec.at d.w bin) * (vec.at q bin)
                    done acc bins
                done epsilon data.instruments
            fi
        done allNotes);
    pitches = vec.divideBy (vec.sum pitches) pitches;

    // Per-instrument numerators; notes the instrument cannot play get
    // epsilon without touching its template.
    sources = mapIntoHash id
        do instrument: vec.fromList
           (map do note:
                if not inRange data.ranges instrument note then epsilon else
                    d = distributionsFor data instrument note;
                    fold do acc bin:
                        acc + (vec.at d.w bin) * (vec.at q bin)
                    done epsilon bins
                fi
            done allNotes)
        done data.instruments;
   
    // Per-note sum of the numerators across all instruments
    sourceDenoms = fold do acc instrument:
        vec.add [acc, sources[instrument]]
    done (vec.zeros noteCount) data.instruments;

    // Diagnostic plotting, kept exactly as it was.  NOTE(review): the first
    // line uses a hard-coded instrument key.
        \() (plot.plot [ Caption "Source numerators for piano", Vector (sources["piano-maps-SptkBGCl"]) ]);

        \() (plot.plot [ Caption "Source denominators", Vector sourceDenoms ]);
    
    // Normalise each instrument's numerators by the per-note totals
    sources = mapIntoHash id
        do instrument: 
            vec.divide sources[instrument] sourceDenoms
        done (keys sources);
 
    data with { 
        pitches,
        sources,
    });

// Module value: the functions exposed to code that loads em_onecolumn.
// inRange, select, distributionsFor and epsilon are not exported.
{
    initialise,
    performExpectation,
    performMaximisation,
}