diff yetilab/matrix/matrix.yeti @ 249:1ea5bf6e76b6 sparse

A reasonable sparse multiply, and a bit quicker dense one
author Chris Cannam
date Mon, 20 May 2013 22:17:19 +0100
parents 58e98d146dc1
children 9fe3192cce38
line wrap: on
line diff
--- a/yetilab/matrix/matrix.yeti	Mon May 20 18:11:44 2013 +0100
+++ b/yetilab/matrix/matrix.yeti	Mon May 20 22:17:19 2013 +0100
@@ -46,21 +46,24 @@
 width m = (size m).columns;
 height m = (size m).rows;
 
-density m =
-   ({ rows, columns } = size m;
-    cells = rows * columns;
-    nonZeroCells d =
+nonZeroValues m =
+   (nz d =
         sum
            (map do v:
                 sum (map do n: if n == 0 then 0 else 1 fi done (vec.list v))
                 done d);
     case m of 
-    DenseRows d: (nonZeroCells d) / cells;
-    DenseCols d: (nonZeroCells d) / cells;
-    SparseCSR d: (vec.length d.values) / cells;
-    SparseCSC d: (vec.length d.values) / cells;
+    DenseRows d: nz d;
+    DenseCols d: nz d;
+    SparseCSR d: vec.length d.values;
+    SparseCSC d: vec.length d.values;
     esac);
 
+density m =
+   ({ rows, columns } = size m;
+    cells = rows * columns;
+    (nonZeroValues m) / cells);
+
 sparseSlice n d =
    (start = d.pointers[n];
     end = d.pointers[n+1];
@@ -309,6 +312,8 @@
 // Compare matrices using the given comparator for individual cells.
 // Note that matrices with different storage order but the same
 // contents are equal, although comparing them is slow.
+//!!! Document the fact that sparse matrices can only be equal if they
+// have the same set of non-zero cells (regardless of comparator used)
 equalUnder comparator =
     equal' comparator (vec.equalUnder comparator);
 
@@ -353,14 +358,50 @@
 abs' m =
     generate do row col: abs (getAt row col m) done (size m);
 
-//!!! todo: proper sparse multiply
+sparseProductLeft size m1 m2 =
+   (e = enumerateSparse m1;
+    data = array (map \(new double[size.rows]) [1..size.columns]);
+    for [0..size.columns - 1] do j':
+        c = getColumn j' m2;
+        for e do { v, i, j }:
+            data[j'][i] := data[j'][i] + v * (vec.at j c);
+        done;
+    done;
+    DenseCols (array (map vec.vector (list data))));
+
+sparseProductRight size m1 m2 =
+   (e = enumerateSparse m2;
+    data = array (map \(new double[size.columns]) [1..size.rows]);
+    for [0..size.rows - 1] do i':
+        r = getRow i' m1;
+        for e do { v, i, j }:
+            data[i'][j] := data[i'][j] + v * (vec.at i r);
+        done;
+    done;
+    DenseRows (array (map vec.vector (list data))));
+
+denseProduct size m1 m2 =
+   (data = array (map \(new double[size.rows]) [1..size.columns]);
+    for [0..size.rows - 1] do i:
+        row = getRow i m1;
+        for [0..size.columns - 1] do j:
+            data[j][i] := bf.sum (bf.multiply row (getColumn j m2));
+        done;
+    done;
+    DenseCols (array (map vec.vector (list data))));
+
 product m1 m2 =
     if (size m1).columns != (size m2).rows
     then failWith "Matrix dimensions incompatible: \(size m1), \(size m2) (\((size m1).columns) != \((size m2).rows))";
-    else
-        generate do row col:
-            bf.sum (bf.multiply (getRow row m1) (getColumn col m2))
-        done { rows = (size m1).rows, columns = (size m2).columns }
+    else 
+        size = { rows = (size m1).rows, columns = (size m2).columns };
+        if isSparse? m1 then
+            sparseProductLeft size m1 m2
+        elif isSparse? m2 then
+            sparseProductRight size m1 m2
+        else
+            denseProduct size m1 m2
+        fi;
     fi;
 
 asRows m =
@@ -464,6 +505,7 @@
     width,
     height,
     density,
+    nonZeroValues,
     getAt,
     getColumn,
     getRow,
@@ -508,6 +550,7 @@
     width is matrix -> number,
     height is matrix -> number,
     density is matrix -> number,
+    nonZeroValues is matrix -> number,
     getAt is number -> number -> matrix -> number,
     getColumn is number -> matrix -> vector,
     getRow is number -> matrix -> vector,