diff yeti/em.yeti @ 14:a91de434feb8

More (sometimes baffled) annotations and a bit of work on the EM
author Chris Cannam
date Mon, 24 Mar 2014 16:31:20 +0000
parents e15bc63cb146
children 2b7257e4fc8a
line wrap: on
line diff
--- a/yeti/em.yeti	Fri Mar 21 18:12:38 2014 +0000
+++ b/yeti/em.yeti	Mon Mar 24 16:31:20 2014 +0000
@@ -5,27 +5,68 @@
 vec = load may.vector;
 mat = load may.matrix;
 
-initialiseEM ranges notes size =
+inRange ranges instrument note =
+    note >= ranges[instrument].lowest and note <= ranges[instrument].highest;
+
+initialise ranges templates notes size =
+   (instruments = keys ranges;
     {
-        pitches = // z in the original
+        pitches = // z in the original. 1xN per note
             array (map do note:
-                map \(mm.random ()) [0..size.columns-1]
+                mat.randomMatrix { rows = 1, columns = size.columns }
             done [0..notes-1]),
-        sources =
-            mapIntoHash id // u in the original
+        sources = // u in the original. 1xN per note-instrument
+            mapIntoHash id
                 do instrument:
                     array (map do note:
-                        if note >= ranges[instrument].lowestNote and
-                           note <= ranges[instrument].highestNote
-                        then vec.ones size.columns
-                        else vec.zeros size.columns
-                        fi
+                        mat.constMatrix
+                           (if inRange ranges instrument note then 1 else 0 fi)
+                           (size with { rows = 1 })
                     done [0..notes-1])
-                done (keys ranges);
-    };
+                done instruments,
+        instruments,
+        templates,
+        ranges,
+        lowest = head (sort (map do i: ranges[i].lowest done instruments)),
+        highest = head (reverse (sort (map do i: ranges[i].highest done instruments))),
+    });
+
+epsilon = 1e-16;
+
+select predicate = concatMap do v: if predicate v then [v] else [] fi done;
+
+performExpectation data chunk =
+   (estimate = 
+        fold do acc instrument:
+            fold do acc note:
+                template = mat.getColumn note data.templates[instrument];
+                w = mat.repeatedHorizontal (mat.width chunk) (mat.newColumnVector template);
+                p = mat.repeatedVertical (mat.height chunk) data.pitches[note];
+                s = mat.repeatedVertical (mat.height chunk) data.sources[instrument][note];
+                mat.sum [acc, mat.entryWiseProduct [w, p, s]];
+            done acc [data.ranges[instrument].lowest .. 
+                      data.ranges[instrument].highest]
+        done (mat.constMatrix epsilon (mat.size chunk)) data.instruments;
+    mat.entryWiseDivide chunk estimate);
+
+performMaximisation data chunk error =
+   (fold do acc note:
+        fold do acc instrument:
+            template = mat.getColumn note data.templates[instrument];
+            w = mat.repeatedHorizontal (mat.width chunk) (mat.newColumnVector template);
+            p = mat.repeatedVertical (mat.height chunk) data.pitches[note];
+            s = mat.repeatedVertical (mat.height chunk) data.sources[instrument][note];
+
+
+            mat.sum [acc, mat.entryWiseProduct [w, s, error]]
+
+        done acc (select do i: inRange data.ranges i note done data.instruments)
+    done (mat.constMatrix epsilon (mat.size chunk)) [data.lowest .. data.highest]);
 
 {
-    initialiseEM
+    initialise,
+    performExpectation,
+    performMaximisation,
 }