changeset 258:f3b7b5d20f88

Replace matrix thresholded with a more general filter function that works for sparse matrices too
author Chris Cannam
date Wed, 22 May 2013 08:56:51 +0100
parents f00ab8baa6d7
children fae62dca8048
files yetilab/matrix/matrix.yeti yetilab/matrix/test/speedtest.yeti yetilab/matrix/test/test_matrix.yeti
diffstat 3 files changed, 22 insertions(+), 11 deletions(-) [+]
line wrap: on
line diff
--- a/yetilab/matrix/matrix.yeti	Tue May 21 22:37:28 2013 +0100
+++ b/yetilab/matrix/matrix.yeti	Wed May 22 08:56:51 2013 +0100
@@ -356,11 +356,6 @@
 newColumnVector data = //!!! NB does not copy data
     DenseCols (array [data]);
 
-thresholded threshold m = //!!! v inefficient; and should take a threshold function?
-    generate do row col:
-        v = getAt row col m; if (abs v) > threshold then v else 0 fi
-    done (size m);
-
 denseLinearOp op m1 m2 =
     if isRowMajor? m1 then
         newMatrix (RowMajor ()) 
@@ -434,6 +429,21 @@
         newMatrix (ColumnMajor ()) (map bf.abs (asColumns m));
     fi;
 
+filter f m =
+    if isSparse? m then
+        makeSparse
+            if isRowMajor? m then (RowMajor ()) else (ColumnMajor ()) fi
+            (size m)
+            (map do { i, j, v }: { i, j, v = if f v then v else 0 fi } done (enumerate m))
+    else
+        vfilter = vec.fromList . (map do i: if f i then i else 0 fi done) . vec.list;
+        if isRowMajor? m then
+            newMatrix (RowMajor ()) (map vfilter (asRows m));
+        else
+            newMatrix (ColumnMajor ()) (map vfilter (asColumns m));
+        fi;
+    fi;
+
 sparseProductLeft size m1 m2 =
    (e = enumerateSparse m1;
     data = array (map \(new double[size.rows]) [1..size.columns]);
@@ -627,13 +637,13 @@
     toSparse,
     toDense,
     scaled,
-    thresholded,
     resizedTo,
     asRows,
     asColumns,
     sum = sum',
     difference,
     abs = abs',
+    filter,
     product,
     concat,
     rowSlice,
@@ -680,6 +690,7 @@
     sum is matrix -> matrix -> matrix,
     difference is matrix -> matrix -> matrix,
     abs is matrix -> matrix,
+    filter is (number -> boolean) -> matrix -> matrix,
     product is matrix -> matrix -> matrix,
     concat is (Horizontal () | Vertical ()) -> list<matrix> -> matrix,
     rowSlice is number -> number -> matrix -> matrix, 
--- a/yetilab/matrix/test/speedtest.yeti	Tue May 21 22:37:28 2013 +0100
+++ b/yetilab/matrix/test/speedtest.yeti	Wed May 22 08:56:51 2013 +0100
@@ -20,7 +20,7 @@
     m = time \(mat.randomMatrix { rows = sz, columns = sz });
     makeSparse () = 
        (print "Making \(sparsity * 100)% sparse version (as dense matrix)...";
-        t = time \(mat.thresholded sparsity m);
+        t = time \(mat.filter (> sparsity) m);
         println "Reported density: \(mat.density t) (non-zero values: \(mat.nonZeroValues t))";
         print "Converting to sparse matrix...";
         s = time \(mat.toSparse t);
--- a/yetilab/matrix/test/test_matrix.yeti	Tue May 21 22:37:28 2013 +0100
+++ b/yetilab/matrix/test/test_matrix.yeti	Wed May 22 08:56:51 2013 +0100
@@ -449,12 +449,12 @@
         compareMatrices (mat.toSparse (mat.toDense m)) m
 ),
 
-"thresholded-\(name)": \(
+"filter-\(name)": \(
     m = newMatrix (ColumnMajor ()) [[1,2,0],[-1,-4,6],[0,0,3]];
     compareMatrices
-       (mat.thresholded 2 m)
-       (newMatrix (ColumnMajor ()) [[0,0,0],[0,-4,6],[0,0,3]]) and
-        compare (mat.density (mat.thresholded 2 m)) (3/9)
+       (mat.filter (> 2) m)
+       (newMatrix (ColumnMajor ()) [[0,0,0],[0,0,6],[0,0,3]]) and
+        compare (mat.density (mat.filter (> 2) m)) (2/9)
 ),
 
 "newSparseMatrix-\(name)": \(