diff yetilab/matrix/matrix.yeti @ 261:53ff481f1a41

Implement and test sparse resizedTo; some tidying
author Chris Cannam
date Wed, 22 May 2013 15:02:17 +0100
parents de770971a628
children c206de7c3018
line wrap: on
line diff
--- a/yetilab/matrix/matrix.yeti	Wed May 22 13:54:15 2013 +0100
+++ b/yetilab/matrix/matrix.yeti	Wed May 22 15:02:17 2013 +0100
@@ -141,6 +141,16 @@
     SparseCSC _: true;
     esac;
 
+typeOf m =
+    if isRowMajor? m then RowMajor ()
+    else ColumnMajor ()
+    fi;
+
+flippedTypeOf m =
+    if isRowMajor? m then ColumnMajor ()
+    else RowMajor ()
+    fi;
+
 newColumnMajorStorage { rows, columns } = 
     if rows < 1 then array []
     else array (map \(vec.zeros rows) [1..columns])
@@ -208,6 +218,8 @@
 enumerate m =
     if isSparse? m then enumerateSparse m else enumerateDense m fi;
 
+// Make a sparse matrix from entries whose i, j values are known to be
+// within range
 makeSparse type size data =
    (isRow = case type of RowMajor (): true; ColumnMajor (): false esac;
     ordered = 
@@ -244,13 +256,22 @@
         extent = minorSize,
     });
 
+// Make a sparse matrix from entries that may contain out-of-range
+// cells which need to be filtered out. This is the public API for
+// makeSparse and is also used to discard out-of-range cells from
+// resizedTo.
+newSparseMatrix type size data =
+    makeSparse type size
+       (filter
+            do { i, j, v }:
+                i == int i and i >= 0 and i < size.rows and 
+                j == int j and j >= 0 and j < size.columns
+            done data);
+
 toSparse m =
     if isSparse? m then m
     else
-        makeSparse 
-            if isRowMajor? m then RowMajor () else ColumnMajor () fi
-               (size m)
-               (enumerateDense m);
+        makeSparse (typeOf m) (size m) (enumerateDense m);
     fi;
 
 toDense m =
@@ -275,11 +296,7 @@
 
 flipped m =
     if isSparse? m then
-        if isRowMajor? m then
-            makeSparse (ColumnMajor ()) (size m) (enumerateSparse m)
-        else
-            makeSparse (RowMajor ()) (size m) (enumerateSparse m)
-        fi
+        makeSparse (flippedTypeOf m) (size m) (enumerateSparse m)
     else
         if isRowMajor? m then
             generate do row col: at' m row col done (size m);
@@ -356,10 +373,10 @@
 
 denseLinearOp op m1 m2 =
     if isRowMajor? m1 then
-        newMatrix (RowMajor ()) 
+        newMatrix (typeOf m1) 
            (map2 do c1 c2: op c1 c2 done (asRows m1) (asRows m2));
     else
-        newMatrix (ColumnMajor ()) 
+        newMatrix (typeOf m1) 
            (map2 do c1 c2: op c1 c2 done (asColumns m1) (asColumns m2));
     fi;
 
@@ -380,10 +397,7 @@
             kk = keys h[i];
             map2 do j v: { i, j, v } done kk (map (at h[i]) kk)
             done (keys h));
-    makeSparse
-        if isRowMajor? m1 then (RowMajor ()) else (ColumnMajor ()) fi
-        (size m1) 
-        entries);
+    makeSparse (typeOf m1) (size m1) entries);
 
 sum' m1 m2 =
     if (size m1) != (size m2)
@@ -405,40 +419,35 @@
 
 scaled factor m =
     if isSparse? m then
-        makeSparse
-            if isRowMajor? m then (RowMajor ()) else (ColumnMajor ()) fi
-            (size m)
-            (map do { i, j, v }: { i, j, v = factor * v } done (enumerate m))
+        makeSparse (typeOf m) (size m)
+           (map do { i, j, v }: { i, j, v = factor * v } done (enumerate m))
     elif isRowMajor? m then
-        newMatrix (RowMajor ()) (map (bf.scaled factor) (asRows m));
+        newMatrix (typeOf m) (map (bf.scaled factor) (asRows m));
     else
-        newMatrix (ColumnMajor ()) (map (bf.scaled factor) (asColumns m));
+        newMatrix (typeOf m) (map (bf.scaled factor) (asColumns m));
     fi;
 
 abs' m =
     if isSparse? m then
-        makeSparse
-            if isRowMajor? m then (RowMajor ()) else (ColumnMajor ()) fi
-            (size m)
-            (map do { i, j, v }: { i, j, v = abs v } done (enumerate m))
+        makeSparse (typeOf m) (size m)
+           (map do { i, j, v }: { i, j, v = abs v } done (enumerate m))
     elif isRowMajor? m then
-        newMatrix (RowMajor ()) (map bf.abs (asRows m));
+        newMatrix (typeOf m) (map bf.abs (asRows m));
     else
-        newMatrix (ColumnMajor ()) (map bf.abs (asColumns m));
+        newMatrix (typeOf m) (map bf.abs (asColumns m));
     fi;
 
 filter f m =
     if isSparse? m then
-        makeSparse
-            if isRowMajor? m then (RowMajor ()) else (ColumnMajor ()) fi
-            (size m)
-            (map do { i, j, v }: { i, j, v = if f v then v else 0 fi } done (enumerate m))
+        makeSparse (typeOf m) (size m)
+           (map do { i, j, v }: { i, j, v = if f v then v else 0 fi } done
+               (enumerate m))
     else
         vfilter = vec.fromList . (map do i: if f i then i else 0 fi done) . vec.list;
         if isRowMajor? m then
-            newMatrix (RowMajor ()) (map vfilter (asRows m));
+            newMatrix (typeOf m) (map vfilter (asRows m));
         else
-            newMatrix (ColumnMajor ()) (map vfilter (asColumns m));
+            newMatrix (typeOf m) (map vfilter (asColumns m));
         fi;
     fi;
 
@@ -549,9 +558,7 @@
                 ui uj rest);
          _: []
         esac;
-    makeSparse
-        if isRowMajor? first then (RowMajor ()) else (ColumnMajor ()) fi
-        { rows, columns }
+    makeSparse (typeOf first) { rows, columns }
         if direction == Vertical () then entries 0 0 1 0 mm
         else entries 0 0 0 1 mm fi);
 
@@ -606,10 +613,13 @@
         DenseRows (array (map do v: vec.slice v start end done (asRows m)))
     fi;
 
-//!!! needs sparse version
 resizedTo newsize m =
    (if newsize == (size m) then
         m
+    elif isSparse? m then
+        // don't call makeSparse directly: want to discard
+        // out-of-range cells
+        newSparseMatrix (typeOf m) newsize (enumerateSparse m)
     elif (height m) == 0 or (width m) == 0 then
         zeroMatrixWithTypeOf m newsize;
     else
@@ -676,7 +686,7 @@
     newMatrix,
     newRowVector,
     newColumnVector,
-    newSparseMatrix = makeSparse,
+    newSparseMatrix,
     enumerate
 }
 as