changeset 14:a91de434feb8

More (sometimes baffled) annotations and a bit of work on the EM
author Chris Cannam
date Mon, 24 Mar 2014 16:31:20 +0000
parents e15bc63cb146
children 2b7257e4fc8a
files notes/cplcaMT-annotated.m yeti/em.yeti yeti/silvet.yeti yeti/templates.yeti
diffstat 4 files changed, 100 insertions(+), 22 deletions(-) [+]
line wrap: on
line diff
--- a/notes/cplcaMT-annotated.m	Fri Mar 21 18:12:38 2014 +0000
+++ b/notes/cplcaMT-annotated.m	Mon Mar 24 16:31:20 2014 +0000
@@ -267,13 +267,29 @@
         nh=eps;
         for r=1:R
             if( (pa(r,1) <= k &&  k <= pa(r,2)) )
+
+	        %% so we're accumulating to nh (which is per-note but
+                %% across all instruments) here
+
+	        %% this is a convolution of the error (xbar) with w,
+	        %% carried out as a frequency-domain multiplication.
+	        %% so it's like xbar-for-all-w
                 c = abs( real( ifftn( fx .* fw{r,k} )));
-                nh1 = eval( fnh);
-                nh1 = nh1 .*repmat(u{r,k},1,size(h{k},1))';
-                nh = nh + nh1;
+
+                nh1 = eval( fnh); %% this one is highly mysterious
+
+		%% take the 100x1 note range matrix, repeat to 100x5
+		%% (as h{k} is 5x100), transpose, multiply nh1 by that
+                nh1 = nh1 .* repmat(u{r,k},1,size(h{k},1))';
+                nh = nh + nh1; %% so nh will presumably be 100x5 too
                 
-                nhu = eval( fnh);
+                nhu = eval( fnh); %% more mystery
+
+		%% h{k} is 5x100, I'd expect this to be 100x5, I must
+		%% have got something transposed somewhere
                 nhu = nhu .* h{k};
+
+		%% so I guess this is xbar-for-all-w-for-all-h?
                 nu = sum(nhu)';
                 nu = u{r,k} .* nu + eps;
                 if lu
--- a/yeti/em.yeti	Fri Mar 21 18:12:38 2014 +0000
+++ b/yeti/em.yeti	Mon Mar 24 16:31:20 2014 +0000
@@ -5,27 +5,68 @@
 vec = load may.vector;
 mat = load may.matrix;
 
-initialiseEM ranges notes size =
+inRange ranges instrument note =
+    note >= ranges[instrument].lowest and note <= ranges[instrument].highest;
+
+initialise ranges templates notes size =
+   (instruments = keys ranges;
     {
-        pitches = // z in the original
+        pitches = // z in the original. 1xN per note
             array (map do note:
-                map \(mm.random ()) [0..size.columns-1]
+                mat.randomMatrix { rows = 1, columns = size.columns }
             done [0..notes-1]),
-        sources =
-            mapIntoHash id // u in the original
+        sources = // u in the original. 1xN per note-instrument
+            mapIntoHash id
                 do instrument:
                     array (map do note:
-                        if note >= ranges[instrument].lowestNote and
-                           note <= ranges[instrument].highestNote
-                        then vec.ones size.columns
-                        else vec.zeros size.columns
-                        fi
+                        mat.constMatrix
+                           (if inRange ranges instrument note then 1 else 0 fi)
+                           (size with { rows = 1 })
                     done [0..notes-1])
-                done (keys ranges);
-    };
+                done instruments,
+        instruments,
+        templates,
+        ranges,
+        lowest = head (sort (map do i: ranges[i].lowest done instruments)),
+        highest = head (reverse (sort (map do i: ranges[i].highest done instruments))),
+    });
+
+epsilon = 1e-16;
+
+select predicate = concatMap do v: if predicate v then [v] else [] fi done;
+
+performExpectation data chunk =
+   (estimate = 
+        fold do acc instrument:
+            fold do acc note:
+                template = mat.getColumn note data.templates[instrument];
+                w = mat.repeatedHorizontal (mat.width chunk) (mat.newColumnVector template);
+                p = mat.repeatedVertical (mat.height chunk) data.pitches[note];
+                s = mat.repeatedVertical (mat.height chunk) data.sources[instrument][note];
+                mat.sum [acc, mat.entryWiseProduct [w, p, s]];
+            done acc [data.ranges[instrument].lowest .. 
+                      data.ranges[instrument].highest]
+        done (mat.constMatrix epsilon (mat.size chunk)) data.instruments;
+    mat.entryWiseDivide chunk estimate);
+
+performMaximisation data chunk error =
+   (fold do acc note:
+        fold do acc instrument:
+            template = mat.getColumn note data.templates[instrument];
+            w = mat.repeatedHorizontal (mat.width chunk) (mat.newColumnVector template);
+            p = mat.repeatedVertical (mat.height chunk) data.pitches[note];
+            s = mat.repeatedVertical (mat.height chunk) data.sources[instrument][note];
+
+
+            mat.sum [acc, mat.entryWiseProduct [w, s, error]]
+
+        done acc (select do i: inRange data.ranges i note done data.instruments)
+    done (mat.constMatrix epsilon (mat.size chunk)) [data.lowest .. data.highest]);
 
 {
-    initialiseEM
+    initialise,
+    performExpectation,
+    performMaximisation,
 }
 
 
--- a/yeti/silvet.yeti	Fri Mar 21 18:12:38 2014 +0000
+++ b/yeti/silvet.yeti	Mon Mar 24 16:31:20 2014 +0000
@@ -3,10 +3,12 @@
 
 { prepareTimeFrequency } = load timefreq;
 { loadTemplates, extractRanges } = load templates;
-{ initialiseEM } = load em;
+
+em = load em;
 
 mat = load may.matrix;
 vec = load may.vector;
+plot = load may.plot;
 
 templates = loadTemplates ();
 
@@ -14,7 +16,7 @@
 
 eprintln "\nWe have \(length (keys templates)) instruments:";
 for (sort (keys templates)) do k:
-    eprintln " * \(k) \(mat.size templates[k]) range \(ranges[k].lowestNote) -> \(ranges[k].highestNote)";
+    eprintln " * \(k) \(mat.size templates[k]) range \(ranges[k].lowest) -> \(ranges[k].highest)";
 done;
 eprintln "";
 
@@ -24,9 +26,9 @@
 
 chunkSize = { rows = height, columns = 100 };
 
-emdata = initialiseEM ranges 88 chunkSize;
+emdata = em.initialise ranges templates 88 chunkSize;
 
-eprintln "initialised EM data";
+eprintln "initialised EM data: overall pitch range \(emdata.lowest) -> \(emdata.highest)";
 
 chunkify cols = 
     if empty? cols then []
@@ -40,4 +42,23 @@
 
 eprintln "we have \(length chunks) chunks of size \(mat.size (head chunks))";
 
+eprintln "attempting one expectation phase...";
 
+error = em.performExpectation emdata (head chunks);
+
+eprintln "done, result has dimension \(mat.size error)";
+
+eprintln "attempting one maximisation phase...";
+
+newP = em.performMaximisation emdata (head chunks) error;
+
+eprintln "done";
+
+\() (plot.plot [ Grid (head chunks) ]);
+\() (plot.plot [ Grid error ]);
+
+\() (plot.plot [ Grid newP ]);
+
+();
+
+
--- a/yeti/templates.yeti	Fri Mar 21 18:12:38 2014 +0000
+++ b/yeti/templates.yeti	Mon Mar 24 16:31:20 2014 +0000
@@ -36,7 +36,7 @@
         levels = map vec.sum (mat.asColumns (templates[instrument]));
         first = length levels - length (find (>0) levels);
         last = length (find (>0) (reverse levels)) - 1;
-        { lowestNote = first, highestNote = last }
+        { lowest = first, highest = last }
     done (keys templates);
 
 {