changeset 424:e91fc47affd8

Make matrix filter always return a sparse matrix; add matrix "any" and "all"; use these in tests
author Chris Cannam
date Fri, 04 Oct 2013 12:01:19 +0100
parents ba316d36390e
children 3820e2a696f8
files src/may/matrix.yeti src/may/matrix/test/test_matrix.yeti src/may/stream/test/test_resample.yeti src/may/test/test.yeti
diffstat 4 files changed, 67 insertions(+), 45 deletions(-) [+]
line wrap: on
line diff
--- a/src/may/matrix.yeti	Fri Oct 04 11:01:18 2013 +0100
+++ b/src/may/matrix.yeti	Fri Oct 04 12:01:19 2013 +0100
@@ -294,6 +294,7 @@
 // cells which need to be filtered out. This is the public API for
 // makeSparse and is also used to discard out-of-range cells from
 // resizedTo.
+//!!! doc: i is row number, j is column number (throughout, for sparse stuff). Would calling them row/column be better?
 newSparseMatrix type size data =
     makeSparse type size
        (filter
@@ -474,19 +475,17 @@
         newMatrix (typeOf m) (map bf.abs (asColumns m));
     fi;
 
+//!!! doc: filter by predicate, always returns sparse matrix
 filter' f m =
-    if isSparse? m then
-        makeSparse (typeOf m) (size m)
-           (map do { i, j, v }: { i, j, v = if f v then v else 0 fi } done
-               (enumerate m))
-    else
-        vfilter = vec.fromList . (map do i: if f i then i else 0 fi done) . vec.list;
-        if isRowMajor? m then
-            newMatrix (typeOf m) (map vfilter (asRows m));
-        else
-            newMatrix (typeOf m) (map vfilter (asColumns m));
-        fi;
-    fi;
+    makeSparse (typeOf m) (size m)
+       (map do { i, j, v }: { i, j, v = if f v then v else 0 fi } done
+           (enumerate m));
+
+any' f m =
+    any f (map (.v) (enumerate m));
+
+all' f m =
+    all f (map (.v) (enumerate m));
 
 sparseProductLeft size m1 m2 =
    ({ values, indices, pointers } = case m1 of
@@ -779,6 +778,8 @@
     difference,
     abs = abs',
     filter = filter',
+    all = all',
+    any = any',
     product,
     entryWiseProduct,
     concat,
@@ -828,6 +829,8 @@
     difference is matrix -> matrix -> matrix,
     abs is matrix -> matrix,
     filter is (number -> boolean) -> matrix -> matrix,
+    all is (number -> boolean) -> matrix -> boolean,
+    any is (number -> boolean) -> matrix -> boolean,
     product is matrix -> matrix -> matrix,
     entryWiseProduct is matrix -> matrix -> matrix,
     concat is (Horizontal () | Vertical ()) -> list<matrix> -> matrix,
--- a/src/may/matrix/test/test_matrix.yeti	Fri Oct 04 11:01:18 2013 +0100
+++ b/src/may/matrix/test/test_matrix.yeti	Fri Oct 04 12:01:19 2013 +0100
@@ -502,9 +502,26 @@
     compareMatrices
        (mat.filter (> 2) m)
        (newMatrix (ColumnMajor ()) [[0,0,0],[0,0,6],[0,0,3]]) and
+        compare (mat.isSparse? (mat.filter (> 2) m)) true and
         compare (mat.density (mat.filter (> 2) m)) (2/9)
 ),
 
+"all-\(name)": \(
+    m = newMatrix (ColumnMajor ()) [[1,2,0],[-1,-4,6],[0,0,3]];
+    compare (mat.all (== 9) m) false and
+    compare (mat.all (!= 9) m) true and
+    compare (mat.all (== 2) m) false and
+    compare (mat.all (!= 2) m) false
+),
+
+"any-\(name)": \(
+    m = newMatrix (ColumnMajor ()) [[1,2,0],[-1,-4,6],[0,0,3]];
+    compare (mat.any (== 9) m) false and
+    compare (mat.any (!= 9) m) true and
+    compare (mat.any (== 2) m) true and
+    compare (mat.any (!= 2) m) true
+),
+
 "newSparseMatrix-\(name)": \(
     s = mat.newSparseMatrix (ColumnMajor ()) { rows = 2, columns = 3 } [
         { i = 0, j = 0, v = 1 },
--- a/src/may/stream/test/test_resample.yeti	Fri Oct 04 11:01:18 2013 +0100
+++ b/src/may/stream/test/test_resample.yeti	Fri Oct 04 12:01:19 2013 +0100
@@ -15,7 +15,7 @@
 
 //pl = { plot things = true; };
 
-{ compare, compareUsing, assert, time } = load may.test.test;
+{ compare, compareUsing, compareMatrices, assert, time } = load may.test.test;
 
 //!!! This and gcd should be Somewhere
 nextPowerOfTwo n =
@@ -30,15 +30,6 @@
     windowed = win.windowedRows win.hann data;
     syn.precalculated stream.sampleRate windowed);
 
-compareClose = compareUsing 
-    do m1 m2:
-        length m1 == length m2 and 
-            all id (map2 do v1 v2:
-                length v1 == length v2 and
-                    all id (map2 do a b: abs(a - b) < 1e-10 done v1 v2)
-                done m1 m2);
-    done;
-
 [
 
 // Test for duration of decimated stream (does not test contents, that
@@ -128,12 +119,7 @@
     result = output.read 32;
     reference = syn.sinusoid 16 2;
     expected = (windowedSignalFrom reference 32).read 32;
-    compareOutputs a b = compareClose
-       (map vec.list (mat.asRows a)) (map vec.list (mat.asRows b));
-\() (pl.plot [Vector (mat.getRow 0 result), Vector (mat.getRow 0 expected)]);
-println "diff is \(mat.difference result expected)";
-\() (pl.plot [Vector (mat.getRow 0 (mat.difference result expected))]);
-    compareOutputs result expected;
+    compareMatrices 1e-8 result expected;
 ),
 
 "decimated-sine": \(
@@ -146,14 +132,7 @@
     reference = syn.sinusoid 16 1;
 //    expected = mat.columnSlice (reference.read 200) 50 150;
     expected = (windowedSignalFrom reference 200).read 200;
-    compareOutputs a b = compareClose
-       (map vec.list (mat.asRows a)) (map vec.list (mat.asRows b));
-    if not compareOutputs result expected then
-\() (pl.plot [Vector (mat.getRow 0 result), Vector (mat.getRow 0 expected)]);
-        println "diff: \(mat.difference result expected)";
-\() (    pl.plot [ Vector (mat.getRow 0 (mat.difference result expected)) ] );
-	false
-    else true fi
+    compareMatrices 1e-8 result expected;
 ),
 */
 
@@ -179,7 +158,9 @@
     inmag = bf.divideBy incount (vec.resizedTo speclen (fft.realForwardMagnitude incount insig));
     outmag = bf.divideBy outcount (vec.resizedTo speclen (fft.realForwardMagnitude outcount outsig));
     diff = bf.subtract inmag outmag;
-//    compareClose [vec.list inmag] [vec.list outmag] or
+    compareMatrices 1e-7 (mat.newRowVector outmag) (mat.newRowVector inmag);
+/*
+//    compareMatrices 1e-8 [vec.list inmag] [vec.list outmag] or
         (//println "inmag: \(vec.list inmag)";
          //println "outmag: \(vec.list outmag)";
          //println "diff: \(vec.list diff)";
@@ -195,6 +176,7 @@
 //         \() (pl.plot [Vector inmag, Vector outmag]);
 //         \() (pl.plot [Vector diff]);
          false);
+*/
 ),
 
 
@@ -209,9 +191,9 @@
     up = resample.interpolated factor input;
     result = mat.getRow 0 (up.read (factor * vec.length data));
     phase = 0;
-    a = vec.list data;
-    b = map do i: vec.at result (i*factor + phase) done [0..vec.length data - 1];
-    compareClose [b] [a];
+    b = vec.fromList
+       (map do i: vec.at result (i*factor + phase) done [0..vec.length data - 1]);
+    compareMatrices 1e-8 (mat.newRowVector b) (mat.newRowVector data);
 ),
 
 "interpolated-rs-misc": \(
@@ -224,9 +206,9 @@
     up = resample.resampledTo (factor * input.sampleRate) input;
     result = mat.getRow 0 (up.read (factor * vec.length data));
     phase = 0;
-    a = vec.list data;
-    b = map do i: vec.at result (i*factor + phase) done [0..vec.length data - 1];
-    compareClose [b] [a];
+    b = vec.fromList
+       (map do i: vec.at result (i*factor + phase) done [0..vec.length data - 1]);
+    compareMatrices 1e-8 (mat.newRowVector b) (mat.newRowVector data);
 ),
 
 /*
@@ -248,7 +230,7 @@
         output.read (vec.length data));
 
     result = updown 0;
-    if not compareClose [vec.list (mat.getRow 0 result)] [vec.list data] then
+    if not compareMatrices 1e-8 [vec.list (mat.getRow 0 result)] [vec.list data] then
         \() (pl.plot [Vector data, 
                       Vector (mat.getRow 0 result), 
                       Vector (mat.getRow 0 (updown 1)),
--- a/src/may/test/test.yeti	Fri Oct 04 11:01:18 2013 +0100
+++ b/src/may/test/test.yeti	Fri Oct 04 12:01:19 2013 +0100
@@ -1,5 +1,7 @@
 module may.test.test;
 
+mat = load may.matrix;
+
 import yeti.lang: FailureException;
 
 var goodCompares = 0;
@@ -24,6 +26,24 @@
 
 compare obtained expected = compareUsing (==) obtained expected;
 
+compareMatrices tolerance obtained expected =
+   (d = mat.abs (mat.difference obtained expected);
+    if mat.all (< tolerance) d then
+        goodCompares := goodCompares + 1;
+        true;
+    else
+        println "** value(s) outside tolerance \(tolerance) from expected:";
+        count = 40;
+        faulty = mat.enumerate (mat.filter (>= tolerance) d);
+        for (take count faulty) do f:
+            println " * at (\(f.i),\(f.j)) expected: \(mat.at expected f.i f.j); obtained: \(mat.at obtained f.i f.j); diff: \(f.v)";
+        done;
+        if length faulty > count then
+            print "** (only first \(count) of \(length faulty) shown)";
+        fi;
+        false;
+    fi);
+
 time msg f =
    (start = System#currentTimeMillis();
     result = f ();
@@ -58,7 +78,7 @@
     bad);
 
 {
-    compare, compareUsing, assert,
+    compare, compareUsing, compareMatrices, assert,
     time,
     runTests, 
 }