changeset 238:0c86d9284f20 sparse

Implement sparse matrix construction, add tests for sparse matrices (currently failing)
author Chris Cannam
date Mon, 20 May 2013 14:18:14 +0100
parents 601dbfcf949d
children 741784624bb6
files yetilab/matrix/matrix.yeti yetilab/matrix/test/test_matrix.yeti
diffstat 2 files changed, 111 insertions(+), 24 deletions(-) [+]
line wrap: on
line diff
--- a/yetilab/matrix/matrix.yeti	Sun May 19 22:21:47 2013 +0100
+++ b/yetilab/matrix/matrix.yeti	Mon May 20 14:18:14 2013 +0100
@@ -159,14 +159,71 @@
      _: [];
     esac);
 
-makeSparse type data =
-//!!! not implemented
-    case type of
-    RowMajor (): 
-        SparseCSR { values = vec.zeros 0, indices = array [], pointers = array [], extent = 0 };
-    ColMajor (): 
-        SparseCSC { values = vec.zeros 0, indices = array [], pointers = array [], extent = 0 };
-    esac;
+makeSparse type size data =
+   (isRow = case type of RowMajor (): true; ColMajor (): false esac;
+    ordered = 
+        sortBy do a b:
+            if a.maj == b.maj then a.min < b.min else a.maj < b.maj fi
+        done
+           (map
+                if isRow then
+                    do { i, j, v }: { maj = i, min = j, v } done;
+                else
+                    do { i, j, v }: { maj = j, min = i, v } done;
+                fi
+                data);
+    tagger = if isRow then SparseCSR else SparseCSC fi;
+    majorSize = if isRow then size.rows else size.columns fi;
+    minorSize = if isRow then size.columns else size.rows fi;
+    majorPointers acc nn n i data =
+        if n < nn then 
+            case data of 
+            d::rest:
+                majorPointers (acc ++ (map \(i) [n..d-1])) nn d (i+1) rest;
+             _: 
+                majorPointers (acc ++ [i]) nn (n+1) i [];
+            esac;
+        else
+            acc
+        fi;
+    tagger {
+        values = vec.fromList (map (.v) ordered),
+        indices = array (map (.min) ordered),
+        pointers = array (majorPointers [] majorSize 0 0 (map (.maj) ordered)),
+        extent = minorSize,
+    });
+
+toSparse threshold m =
+    if isSparse? m then m
+    else
+        { rows, columns } = size m;
+        enumerate threshold m ii jj =
+            case ii of
+            i::irest:
+                case jj of
+                j::rest:
+                    v = getAt i j m;
+                    if abs v > threshold then
+                        { i, j, v } :. \(enumerate threshold m ii rest)
+                    else enumerate threshold m ii rest
+                    fi;
+                 _: enumerate threshold m irest [0..columns-1];
+                esac;
+             _: [];
+            esac;
+        makeSparse 
+            if isRowMajor? m then RowMajor () else ColMajor () fi
+               (size m)
+               (enumerate threshold m [0..rows-1] [0..columns-1]);
+    fi;
+
+toDense m =
+    if not (isSparse? m) then m
+    elif isRowMajor? m then
+        DenseRows (array (map do row: getRow row m done [0..height m - 1]));
+    else
+        DenseCols (array (map do col: getColumn col m done [0..width m - 1]));
+    fi;
 
 constMatrix n = generate do row col: n done;
 randomMatrix = generate do row col: Math#random() done;
@@ -183,9 +240,9 @@
 flipped m =
     if isSparse? m then
         if isRowMajor? m then
-            makeSparse (ColMajor ()) (enumerateSparse m)
+            makeSparse (ColMajor ()) (size m) (enumerateSparse m)
         else
-            makeSparse (RowMajor ()) (enumerateSparse m)
+            makeSparse (RowMajor ()) (size m) (enumerateSparse m)
         fi
     else
         if isRowMajor? m then
@@ -203,30 +260,52 @@
 toColumnMajor m =
     if not isRowMajor? m then m else flipped m fi;
 
-equal' vecComparator m1 m2 =
+equal'' comparator vecComparator m1 m2 =
+    // Prerequisite: m1 and m2 have same sparse-p and storage order
+   (compareLists l1 l2 = all id (map2 comparator l1 l2);
+    compareVecLists vv1 vv2 = all id (map2 vecComparator vv1 vv2);
+    compareSparse d1 d2 =
+        d1.extent == d2.extent and
+        vecComparator d1.values d2.values and
+        compareLists d1.indices d2.indices and
+        compareLists d1.pointers d2.pointers;
+    case m1 of
+    DenseRows d1:
+        case m2 of DenseRows d2: compareVecLists d1 d2; _: false; esac;
+    DenseCols d1:
+        case m2 of DenseCols d2: compareVecLists d1 d2; _: false; esac;
+    SparseCSR d1:
+        case m2 of SparseCSR d2: compareSparse d1 d2; _: false; esac;
+    SparseCSC d1:
+        case m2 of SparseCSC d2: compareSparse d1 d2; _: false; esac;
+    esac);
+
+equal' comparator vecComparator m1 m2 =
     if size m1 != size m2 then 
         false
     elif isRowMajor? m1 != isRowMajor? m2 then
-        equal' vecComparator (flipped m1) m2;
+        equal' comparator vecComparator (flipped m1) m2;
+    elif isSparse? m1 != isSparse? m2 then
+        if isSparse? m1 then
+            equal' comparator vecComparator m1 (toSparse 0 m2)
+        else
+            equal' comparator vecComparator (toSparse 0 m1) m2
+        fi
     else
-        compare d1 d2 = all id (map2 vecComparator d1 d2);
-        case m1 of
-        DenseRows d1: case m2 of DenseRows d2: compare d1 d2; _: false; esac;
-        DenseCols d1: case m2 of DenseCols d2: compare d1 d2; _: false; esac;
-        esac
+        equal'' comparator vecComparator m1 m2
     fi;
 
 // Compare matrices using the given comparator for individual cells.
 // Note that matrices with different storage order but the same
 // contents are equal, although comparing them is slow.
 equalUnder comparator =
-    equal' (vec.equalUnder comparator);
+    equal' comparator (vec.equalUnder comparator);
 
 equal =
-    equal' vec.equal;
+    equal' (==) vec.equal;
 
 newMatrix type data = //!!! NB does not copy data
-   (tagger = case type of RowMajor (): RowM; ColumnMajor (): DenseCols esac;
+   (tagger = case type of RowMajor (): DenseRows; ColumnMajor (): DenseCols esac;
     if empty? data or vec.empty? (head data)
     then zeroSizeMatrix ()
     else tagger (array data)
@@ -383,6 +462,8 @@
     flipped,
     toRowMajor,
     toColumnMajor,
+    toSparse,
+    toDense,
     scaled,
     resizedTo,
     asRows,
@@ -422,6 +503,8 @@
     flipped is matrix -> matrix, 
     toRowMajor is matrix -> matrix, 
     toColumnMajor is matrix -> matrix,
+    toSparse is number -> matrix -> matrix,
+    toDense is matrix -> matrix,
     scaled is number -> matrix -> matrix,
     resizedTo is { .rows is number, .columns is number } -> matrix -> matrix,
     asRows is matrix -> list<vector>, 
--- a/yetilab/matrix/test/test_matrix.yeti	Sun May 19 22:21:47 2013 +0100
+++ b/yetilab/matrix/test/test_matrix.yeti	Mon May 20 14:18:14 2013 +0100
@@ -373,12 +373,16 @@
 
 ]);
 
-colhash = makeTests "column-major" id;
-rowhash = makeTests "row-major" mat.flipped;
+colhash = makeTests "column-dense" id;
+rowhash = makeTests "row-dense" mat.flipped;
+sparsecolhash = makeTests "column-sparse" (mat.toSparse 0);
+sparserowhash = makeTests "row-sparse" ((mat.toSparse 0) . (mat.flipped));
+//sparserowhash2 = makeTests "row-sparse2" ((mat.flipped) . (mat.toSparse 0));
 
 all = [:];
-for (keys colhash) do k: all[k] := colhash[k] done;
-for (keys rowhash) do k: all[k] := rowhash[k] done;
+for [ colhash, rowhash, sparsecolhash, sparserowhash /*, sparserowhash2 */ ] do h:
+    for (keys h) do k: all[k] := h[k] done;
+done;
 
 all is hash<string, () -> boolean>;