view yeti/em_onecolumn.yeti @ 24:0e8ee830b5ee

Normalise source distributions row-wise (per note)
author Chris Cannam
date Mon, 31 Mar 2014 12:46:24 +0100
parents 990b8b8b7e25
children 80f02ff5d37a
line wrap: on
line source

module em_onecolumn;

mm = load may.mathmisc;
vec = load may.vector;
mat = load may.matrix;

plot = load may.plot;

// True when `note` lies within the inclusive [lowest, highest] span
// recorded for `instrument` in `ranges`. `ranges` may be anything
// indexable by `instrument` whose entries carry `lowest` and `highest`.
inRange ranges instrument note =
   (limits = ranges[instrument];
    limits.lowest <= note and limits.highest >= note);

// Scale vector `v` by dividing it by the sum of its elements.
// When that sum is not strictly positive the vector is returned
// unchanged, so no division by zero is attempted.
normalise v =
   (total = vec.sum v;
    if not (total > 0) then
        v
    else
        vec.divideBy total v
    fi);

// Build the initial model state.
//
//   ranges    - indexable by instrument name; each entry has `lowest`
//               and `highest` note fields
//   templates - indexable by instrument name; each entry is a matrix
//   notes     - number of notes in the model
//
// Instruments are processed in sorted name order, and the `instruments`,
// `templates` and `ranges` arrays in the result all share that order.
initialise ranges templates notes =
   (names = sort (keys ranges);

    // Indicator vector over instruments for one note: 1 where the
    // instrument's range includes the note, 0 elsewhere.
    membership note =
        vec.fromList
           (map do name:
                if inRange ranges name note then 1 else 0 fi
            done names);

    // Normalise every column of a matrix independently.
    normaliseColumns m =
        mat.fromColumns (map normalise (mat.asColumns m));

    lowEnds = sort (map do name: ranges[name].lowest done names);
    highEnds = sort (map do name: ranges[name].highest done names);

    {
        // z in the original. 1 per note
        pitches = normalise (vec.randoms notes),
        // u in the original. 1 per note-instrument; one row per note
        sources = mat.fromRows
           (map do note: normalise (membership note) done [0..notes-1]),
        instruments = array names,
        instCount = length names,
        noteCount = notes,
        templates = array
           (map do name: normaliseColumns templates[name] done names),
        ranges = array (map do name: ranges[name] done names),
        // smallest `lowest` and largest `highest` over all instruments
        lowest = head lowEnds,
        highest = head (reverse highEnds),
    });

// Tiny positive constant. In this module it seeds the accumulations in
// the expectation and maximisation steps and stands in for out-of-range
// entries -- presumably so that later divisions never see an exact zero.
epsilon = 1e-16;

// Keep only the elements of `items` for which `predicate` holds,
// preserving their order.
select predicate items =
    concatMap do item:
        if predicate item then [item] else [] fi
    done items;

// Collect the three quantities the EM steps need for one
// (instrument, note) pair:
//   w - column `note` of the instrument's template matrix
//   p - entry `note` of data.pitches
//   s - entry (note, instNo) of data.sources
distributionsFor data instNo note =
   (w = mat.getColumn note data.templates[instNo];
    p = vec.at data.pitches note;
    s = mat.at data.sources note instNo;
    { w, p, s });

// Expectation step for a single column.
//
// The column is normalised, then divided element-wise by the model's
// current estimate of it. The estimate starts as a vector of `epsilon`
// values and gains, for every instrument and every note within that
// instrument's range, the template column scaled by (pitch * source).
performExpectation data column =
   (observed = normalise column;

    // Add one instrument's contribution, over its playable notes,
    // to the running estimate.
    addInstrument running inst =
       (span = data.ranges[inst];
        fold do partial note:
            d = distributionsFor data inst note;
            vec.add [partial, vec.scaled (d.p * d.s) d.w]
        done running [span.lowest .. span.highest]);

    seed = vec.consts epsilon (vec.length observed);
    estimate = fold addInstrument seed [0..data.instCount-1];
    vec.divide observed estimate);

// Maximisation step for a single column.
//
//   data   - model state as produced by `initialise` (or by a previous
//            call to this function)
//   column - the observed column; only its length (the number of bins)
//            is used here
//   q      - vector indexed by bin, as produced by performExpectation
//
// Returns `data` with `pitches` and `sources` replaced:
//   pitches - one value per note, normalised to sum to 1
//   sources - one row per note, one column per instrument, with each
//             row normalised to sum to 1 (i.e. per note)
performMaximisation data column q =
   (bins = [0..vec.length column - 1];

    // Add to `acc` the sum over bins of s * p * w[bin] * q[bin] for
    // one (instrument, note) pair.
    accumulate acc inst note =
       ({ w, p, s } = distributionsFor data inst note;
        fold do total bin:
            total + s * p * (vec.at w bin) * (vec.at q bin)
        done acc bins);

    // Notes outside the overall [lowest, highest] span get epsilon;
    // the rest accumulate over every instrument, starting from epsilon.
    rawPitches = vec.fromList
       (map do note:
            if note >= data.lowest and note <= data.highest then
                fold do acc inst:
                    accumulate acc inst note
                done epsilon [0..data.instCount-1]
            else
                epsilon
            fi
        done [0..data.noteCount-1]);
    pitches = vec.divideBy (vec.sum rawPitches) rawPitches;

    // One row per note, one entry per instrument; instruments that
    // cannot play the note get epsilon.
    rawSources = mat.fromRows
       (map do note: vec.fromList
           (map do inst:
                if not inRange data.ranges inst note then
                    epsilon
                else
                    accumulate epsilon inst note
                fi
            done [0..data.instCount-1])
        done [0..data.noteCount-1]);

    // Normalise row-wise: each note's row is divided by that row's own
    // sum. (Dividing a row element-wise by the vector of all per-note
    // sums would pair instrument indices with note indices and leave
    // the rows unnormalised.)
    sources = mat.fromRows
       (map do note:
            row = mat.getRow note rawSources;
            vec.divideBy (vec.sum row) row
        done [0..data.noteCount-1]);

    data with {
        pitches,
        sources,
    });

// Public interface of the module: state construction plus the two EM
// steps. Everything else defined above stays private to the module.
{
    initialise,
    performExpectation,
    performMaximisation,
}