view yeti/em.yeti @ 372:af71cbdab621 tip

Update bqvec code
author Chris Cannam
date Tue, 19 Nov 2019 10:13:32 +0000
parents cd9fd74931bb
children
line wrap: on
line source

module em;

mm = load may.mathmisc;
vec = load may.vector;
mat = load may.matrix;

// True when `note` falls inside the inclusive lowest..highest span that
// `ranges` records for instrument `inst`.
inRange ranges inst note =
   (span = ranges[inst];
    span.lowest <= note and note <= span.highest);

// Divide vector `v` by the sum of its elements. When that sum is not
// strictly positive the vector is handed back untouched.
normaliseColumn v =
   (total = vec.sum v;
    if not (total > 0) then
        v
    else
        vec.divideBy total v
    fi);

// Apply normaliseColumn to every column of the matrix `chunk`.
normaliseChunk chunk =
    mat.mapColumns normaliseColumn chunk;

// Divide each matrix in `sourceMatrices` entry-wise by the sum of all of
// them. The result is an array in the same order as the input.
// The first matrix supplies the size of the zero starting value, so the
// input must not be empty.
// NOTE(review): entries where every source is 0 give a 0 denominator;
// confirm how mat.entryWiseDivide treats that case.
normaliseSources sourceMatrices =
   (zero = mat.zeroMatrix (mat.size sourceMatrices[0]);
    accumulate running source = mat.sum [running, source];
    total = fold accumulate zero sourceMatrices;
    overTotal source = mat.entryWiseDivide source total;
    array (map overTotal sourceMatrices));

// Build the initial model state.
//   ranges    - map from instrument name to a record with `lowest` and
//               `highest` note indices
//   templates - map from instrument name to that instrument's template matrix
//   notes     - number of notes (rows of the pitch and source matrices)
//   size      - record whose `columns` field gives the matrix width
// Instruments are ordered by sorted name; `ranges` and `templates` in the
// result are arrays indexed by that position.
initialise ranges templates notes size =
   (instrumentNames = sort (keys ranges);
    nInstruments = length instrumentNames;
    // Per-instrument ranges re-indexed by position in instrumentNames.
    spans = array (map do iname: ranges[iname] done instrumentNames);
    shape = { rows = notes, columns = size.columns };
    // Column vector holding 1 for each note inside instrument `inst`'s
    // range and 0 for every other note.
    rangeMask inst =
        mat.newColumnVector
           (vec.fromList
               (map do note:
                    if inRange spans inst note then 1 else 0 fi
                done [0..notes-1]));
    lows = map do span: span.lowest done spans;
    highs = map do span: span.highest done spans;
    {
        // z in the original. 1xN per note
        pitches = normaliseChunk (mat.randomMatrix shape),
        // u in the original. a tensor, 1xN per note-instrument
        sources =
            normaliseSources
               (array
                   (map do inst: mat.tiledTo shape (rangeMask inst) done
                        [0..nInstruments - 1])),
        instrumentNames,
        nInstruments,
        nNotes = notes,
        ranges = spans,
        templates = array
           (map do iname: normaliseChunk templates[iname] done
                instrumentNames),
        // Smallest `lowest` and largest `highest` over all instruments.
        lowest = head (sort lows),
        highest = head (reverse (sort highs)),
    });

// Small positive constant. performExpectation uses it to fill the starting
// matrix of its estimate, and performMaximisation uses it to build the
// one-row matrix that stands in for out-of-range notes and seeds its folds.
epsilon = 1e-16;

// Keep the elements for which `predicate` returns true, in their original
// order. Delegates to the standard `filter` rather than building and
// concatenating one-element lists.
select predicate = filter predicate;

// Collect the three distributions for one (instrument, note) pair:
//   w - column `note` of the instrument's template, as a column vector
//   p - row `note` of data.pitches, as a row vector
//   s - row `note` of the instrument's source matrix, as a row vector
distributionsFor data inst note =
   (templateColumn = mat.getColumn note data.templates[inst];
    pitchRow = mat.getRow note data.pitches;
    sourceRow = mat.getRow note data.sources[inst];
    {
        w = mat.newColumnVector templateColumn,
        p = mat.newRowVector pitchRow,
        s = mat.newRowVector sourceRow,
    });

// Expectation step. Starting from a matrix of epsilon with the size of
// `chunk`, add w * (p .* s) for every instrument and every note in that
// instrument's range. Returns the accumulated `estimate` together with
// `q`, the entry-wise quotient of `chunk` by `estimate`.
performExpectation data chunk =
   (// Add the contribution of one (instrument, note) pair to `total`.
    addNote inst total note =
       ({ w, p, s } = distributionsFor data inst note;
        mat.sum [total, mat.product w (mat.entryWiseProduct [p, s])]);
    // Add the contributions of every note in instrument `inst`'s range.
    addInstrument total inst =
       (span = data.ranges[inst];
        fold (addNote inst) total [span.lowest .. span.highest]);
    initial = mat.constMatrix epsilon (mat.size chunk);
    estimate = fold addInstrument initial [0..data.nInstruments-1];
    { estimate, q = mat.entryWiseDivide chunk estimate });

// Maximisation step. Recomputes the pitch and source distributions from
// `chunk` and returns `data` with its `pitches` and `sources` fields replaced.
// NOTE(review): the `q` parameter is never read in this body; the update is
// weighted by rows of the column-normalised `chunk` instead. Confirm that this
// is intended.
performMaximisation data chunk q =
   (chunk = normaliseChunk chunk;
    columns = mat.width chunk;
    // One-row matrix built from epsilon, as wide as chunk. It stands in for
    // out-of-range notes and is the starting value of both folds below.
    e = mat.constMatrix epsilon { rows = 1, columns };

    // Mutable hash keyed by {inst,note}. It is filled as a side effect while
    // `pitches` is built and read back while `sources` is built, so the
    // order of the two computations below matters.
    noteInstrumentProducts = [:];

    // One row per note, stacked vertically. Notes outside
    // data.lowest..data.highest contribute `e`; otherwise the row is the sum
    // over all instruments of that instrument's `prod`.
    pitches = mat.concatVertical
       (map do note:
            if note < data.lowest or note > data.highest then e else
                fold do acc inst:
                    { w, p, s } = distributionsFor data inst note;
                    // Sum over the rows (bins) of chunk: entry-wise product
                    // of p, s and the chunk row, scaled by w's entry at
                    // (bin, 0).
                    prod =
                        fold do acc bin:
                            mat.sum
                               [acc, 
                                mat.scaled (mat.at w bin 0) 
                                   (mat.entryWiseProduct
                                       [p, s, mat.newRowVector (mat.getRow bin chunk)])]
                        done e [0..mat.height chunk - 1];
                    // Remember the per-(instrument, note) row for `sources`.
                    noteInstrumentProducts[{inst,note}] := prod;
                    mat.sum [acc, prod];
                done e [0..data.nInstruments-1]
            fi;
        done [0..data.nNotes - 1]);
    pitches = normaliseChunk pitches;

    // One matrix per instrument, one row per note. Notes outside the
    // instrument's own range contribute `e`; in-range notes reuse the row
    // stored above.
    // NOTE(review): presumably every needed key is already present because
    // mat.concatVertical consumed the whole mapped list above - verify.
    sources = array
       (map do inst:
           (mat.concatVertical
               (map do note:
                    if not inRange data.ranges inst note then e else
                        noteInstrumentProducts[{inst,note}]
                    fi
                done [0..data.nNotes - 1]))
        done [0..data.nInstruments-1]);
    sources = normaliseSources sources;

    data with {
        pitches,
        sources,
    });



// Public interface of the em module.
{ initialise, performExpectation, performMaximisation }