diff yeti/em.yeti @ 26:fbc4011c7693

Now to try to get the matrix version working!
author Chris Cannam
date Mon, 31 Mar 2014 17:03:51 +0100
parents d75dd38a12a5
children cd9fd74931bb
line wrap: on
line diff
--- a/yeti/em.yeti	Mon Mar 31 15:41:15 2014 +0100
+++ b/yeti/em.yeti	Mon Mar 31 17:03:51 2014 +0100
@@ -5,67 +5,125 @@
 vec = load may.vector;
 mat = load may.matrix;
 
-inRange ranges instrument note =
-    note >= ranges[instrument].lowest and note <= ranges[instrument].highest;
+inRange ranges inst note =
+    note >= ranges[inst].lowest and note <= ranges[inst].highest;
+
+normaliseColumn v =
+   (s = vec.sum v;
+    if s > 0 then vec.divideBy s v 
+    else v
+    fi);
+
+normaliseChunk =
+    mat.mapColumns normaliseColumn;
+
+normaliseSources sourceMatrices =
+   (denom = fold do acc source: mat.sum [acc, source] done
+       (mat.zeroMatrix (mat.size sourceMatrices[0])) sourceMatrices;
+    array
+       (map do source: mat.entryWiseDivide source denom done sourceMatrices));
 
 initialise ranges templates notes size =
-   (instruments = keys ranges;
+   (instrumentNames = sort (keys ranges);
+    ranges = array (map (at ranges) instrumentNames);
     {
         pitches = // z in the original. 1xN per note
-            array (map do note:
-                mat.randomMatrix { rows = 1, columns = size.columns }
-            done [0..notes-1]),
-        sources = // u in the original. 1xN per note-instrument
-            mapIntoHash id
-                do instrument:
-                    array (map do note:
-                        mat.constMatrix
-                           (if inRange ranges instrument note then 1 else 0 fi)
-                           (size with { rows = 1 })
-                    done [0..notes-1])
-                done instruments,
-        instruments,
-        templates,
+            normaliseChunk
+               (mat.randomMatrix { rows = notes, columns = size.columns }),
+        sources = // u in the original. a tensor, 1xN per note-instrument
+            normaliseSources
+               (array
+                   (map do inst:
+                        mat.tiledTo { rows = notes, columns = size.columns }
+                           (mat.newColumnVector
+                               (vec.fromList
+                                   (map do note:
+                                        if inRange ranges inst note
+                                        then 1 
+                                        else 0
+                                        fi
+                                    done [0..notes-1])))
+                    done [0..length instrumentNames - 1])),
+        instrumentNames,
+        nInstruments = length instrumentNames,
+        nNotes = notes,
         ranges,
-        lowest = head (sort (map do i: ranges[i].lowest done instruments)),
-        highest = head (reverse (sort (map do i: ranges[i].highest done instruments))),
+        templates = array 
+           (map do iname:
+                normaliseChunk templates[iname];
+            done instrumentNames),
+        lowest = head (sort (map do r: r.lowest done ranges)),
+        highest = head (reverse (sort (map do r: r.highest done ranges))),
     });
 
 epsilon = 1e-16;
 
 select predicate = concatMap do v: if predicate v then [v] else [] fi done;
 
+distributionsFor data inst note =
+    {
+        w = mat.newColumnVector (mat.getColumn note data.templates[inst]),
+        p = mat.newRowVector (mat.getRow note data.pitches),
+        s = mat.newRowVector (mat.getRow note data.sources[inst]),
+    };
+
 performExpectation data chunk =
    (estimate = 
-        fold do acc instrument:
+        fold do acc inst:
             fold do acc note:
-                template = mat.getColumn note data.templates[instrument];
-                resize = mat.tiledTo (mat.size chunk);
-                w = resize (mat.newColumnVector template);
-                p = resize data.pitches[note];
-                s = resize data.sources[instrument][note];
-                mat.sum [acc, mat.entryWiseProduct [w, p, s]];
-            done acc [data.ranges[instrument].lowest .. 
-                      data.ranges[instrument].highest]
-        done (mat.constMatrix epsilon (mat.size chunk)) data.instruments;
-    mat.entryWiseDivide chunk estimate);
+                { w, p, s } = distributionsFor data inst note;
+                mat.sum [acc, mat.product w (mat.entryWiseProduct [p, s]) ];
+            done acc [data.ranges[inst].lowest .. 
+                      data.ranges[inst].highest]
+        done (mat.constMatrix epsilon (mat.size chunk)) [0..data.nInstruments-1];
+    { estimate, q = mat.entryWiseDivide chunk estimate});
 
-performMaximisation data chunk error =
-   (pitches =
-        fold do acc note:
-            fold do acc instrument:
-                // want sum of error * original for all template and instruments
-                // for this pitch, divided by sum of error * original for all
-                // template and instruments for all pitches
+performMaximisation data chunk q =
+   (chunk = normaliseChunk chunk;
+    columns = mat.width chunk;
+    e = mat.constMatrix epsilon { rows = 1, columns };
 
-            template = mat.getColumn note data.templates[instrument];
-            w = mat.repeatedHorizontal (mat.width chunk) (mat.newColumnVector template);
-            p = mat.repeatedVertical (mat.height chunk) data.pitches[note];
-            s = mat.repeatedVertical (mat.height chunk) data.sources[instrument][note];
-            mat.sum [acc, mat.entryWiseProduct [w, p, s, error]]
-        done acc (select do i: inRange data.ranges i note done data.instruments)
-    done (mat.constMatrix epsilon (mat.size chunk)) [data.lowest .. data.highest];
-    pitches);
+    pitches = mat.concatVertical
+       (map do note:
+            if note < data.lowest or note > data.highest then e else
+                fold do acc inst:
+                    { w, p, s } = distributionsFor data inst note;
+                    fold do acc bin:
+                         mat.sum
+                            [acc, 
+                             mat.scaled (mat.at w bin 0) 
+                                (mat.entryWiseProduct
+                                    [p, s, mat.newRowVector (mat.getRow bin chunk)])]
+                    done acc [0..mat.height chunk - 1]
+                done e [0..data.nInstruments-1]
+            fi;
+        done [0..data.nNotes - 1]);
+    pitches = normaliseChunk pitches;
+
+    sources = array
+       (map do inst:
+           (mat.concatVertical
+               (map do note:
+                    if not inRange data.ranges inst note then e else
+                        { w, p, s } = distributionsFor data inst note;
+                        fold do acc bin:
+                            mat.sum
+                               [acc,
+                                mat.scaled (mat.at w bin 0)
+                                   (mat.entryWiseProduct
+                                       [p, s, mat.newRowVector (mat.getRow bin chunk)])]
+                        done e [0..mat.height chunk - 1]
+                    fi
+                done [0..data.nNotes - 1]))
+        done [0..data.nInstruments-1]);
+    sources = normaliseSources sources;
+
+    data with {
+        pitches,
+        sources,
+    });
+
+
 
 {
     initialise,