changeset 249:1ea5bf6e76b6 sparse

A reasonable sparse multiply, and a bit quicker dense one
author Chris Cannam
date Mon, 20 May 2013 22:17:19 +0100
parents 586d46f64902
children 6a141098c059 95f2565f9471
files yetilab/matrix/matrix.yeti yetilab/matrix/test/speedtest.yeti yetilab/matrix/test/test_matrix.yeti
diffstat 3 files changed, 96 insertions(+), 39 deletions(-) [+]
line wrap: on
line diff
--- a/yetilab/matrix/matrix.yeti	Mon May 20 18:11:44 2013 +0100
+++ b/yetilab/matrix/matrix.yeti	Mon May 20 22:17:19 2013 +0100
@@ -46,21 +46,24 @@
 width m = (size m).columns;
 height m = (size m).rows;
 
-density m =
-   ({ rows, columns } = size m;
-    cells = rows * columns;
-    nonZeroCells d =
+nonZeroValues m = // Number of non-zero cells in matrix m.
+   (nz d = // dense case: count the values not equal to 0 across every vector in d
         sum
            (map do v:
                 sum (map do n: if n == 0 then 0 else 1 fi done (vec.list v))
                 done d);
     case m of 
-    DenseRows d: (nonZeroCells d) / cells;
-    DenseCols d: (nonZeroCells d) / cells;
-    SparseCSR d: (vec.length d.values) / cells;
-    SparseCSC d: (vec.length d.values) / cells;
+    DenseRows d: nz d;
+    DenseCols d: nz d;
+    SparseCSR d: vec.length d.values; // sparse case: length of the stored-values vector; NOTE(review): assumes no explicit zeros are stored - confirm
+    SparseCSC d: vec.length d.values; // same as the CSR case
     esac);
 
+density m = // Fraction of m's cells that are non-zero: nonZeroValues m divided by rows * columns.
+   ({ rows, columns } = size m;
+    cells = rows * columns; // NOTE(review): cells is 0 when either dimension is 0 - confirm the division below is safe for empty matrices
+    (nonZeroValues m) / cells);
+
 sparseSlice n d =
    (start = d.pointers[n];
     end = d.pointers[n+1];
@@ -309,6 +312,8 @@
 // Compare matrices using the given comparator for individual cells.
 // Note that matrices with different storage order but the same
 // contents are equal, although comparing them is slow.
+//!!! Document the fact that sparse matrices can only be equal if they
+// have the same set of non-zero cells (regardless of comparator used)
 equalUnder comparator =
     equal' comparator (vec.equalUnder comparator);
 
@@ -353,14 +358,50 @@
 abs' m =
     generate do row col: abs (getAt row col m) done (size m);
 
-//!!! todo: proper sparse multiply
+sparseProductLeft size m1 m2 = // Product m1 * m2, chosen by `product` when m1 is sparse; size is the result's { rows, columns }; returns DenseCols.
+   (e = enumerateSparse m1; // entries { v, i, j } of m1, used below as value, row index, column index (enumerateSparse is defined outside this hunk)
+    data = array (map \(new double[size.rows]) [1..size.columns]); // one accumulator of size.rows doubles per result column
+    for [0..size.columns - 1] do j': // j' is the result column index
+        c = getColumn j' m2; // column j' of the right-hand operand
+        for e do { v, i, j }:
+            data[j'][i] := data[j'][i] + v * (vec.at j c); // add v times element j of c (per vec.at) into result cell (row i, column j')
+        done;
+    done;
+    DenseCols (array (map vec.vector (list data)))); // pass each accumulated column through vec.vector
+
+sparseProductRight size m1 m2 = // Product m1 * m2, chosen by `product` when only m2 is sparse; size is the result's { rows, columns }; returns DenseRows.
+   (e = enumerateSparse m2; // entries { v, i, j } of m2, used below as value, row index, column index (enumerateSparse is defined outside this hunk)
+    data = array (map \(new double[size.columns]) [1..size.rows]); // one accumulator of size.columns doubles per result row
+    for [0..size.rows - 1] do i': // i' is the result row index
+        r = getRow i' m1; // row i' of the left-hand operand
+        for e do { v, i, j }:
+            data[i'][j] := data[i'][j] + v * (vec.at i r); // add v times element i of r (per vec.at) into result cell (row i', column j)
+        done;
+    done;
+    DenseRows (array (map vec.vector (list data)))); // pass each accumulated row through vec.vector
+
+denseProduct size m1 m2 = // Product m1 * m2, chosen by `product` when neither operand is sparse; size is the result's { rows, columns }; returns DenseCols.
+   (data = array (map \(new double[size.rows]) [1..size.columns]); // data[j][i] holds result cell (row i, column j)
+    for [0..size.rows - 1] do i:
+        row = getRow i m1; // fetched once per result row
+        for [0..size.columns - 1] do j:
+            data[j][i] := bf.sum (bf.multiply row (getColumn j m2)); // presumably the dot product of row i and column j; NOTE(review): getColumn j m2 is re-fetched for every i - consider caching if it is costly
+        done;
+    done;
+    DenseCols (array (map vec.vector (list data)))); // pass each computed column through vec.vector
+
 product m1 m2 =
     if (size m1).columns != (size m2).rows
     then failWith "Matrix dimensions incompatible: \(size m1), \(size m2) (\((size m1).columns) != \((size m2).rows))";
-    else
-        generate do row col:
-            bf.sum (bf.multiply (getRow row m1) (getColumn col m2))
-        done { rows = (size m1).rows, columns = (size m2).columns }
+    else 
+        size = { rows = (size m1).rows, columns = (size m2).columns }; // result dimensions; this local shadows the `size` function for the rest of the branch
+        if isSparse? m1 then // a sparse left operand is checked first, so this path is also taken when both operands are sparse
+            sparseProductLeft size m1 m2
+        elif isSparse? m2 then // only the right operand is sparse
+            sparseProductRight size m1 m2
+        else // neither operand is sparse
+            denseProduct size m1 m2
+        fi;
     fi;
 
 asRows m =
@@ -464,6 +505,7 @@
     width,
     height,
     density,
+    nonZeroValues,
     getAt,
     getColumn,
     getRow,
@@ -508,6 +550,7 @@
     width is matrix -> number,
     height is matrix -> number,
     density is matrix -> number,
+    nonZeroValues is matrix -> number,
     getAt is number -> number -> matrix -> number,
     getColumn is number -> matrix -> vector,
     getRow is number -> matrix -> vector,
--- a/yetilab/matrix/test/speedtest.yeti	Mon May 20 18:11:44 2013 +0100
+++ b/yetilab/matrix/test/speedtest.yeti	Mon May 20 22:17:19 2013 +0100
@@ -4,6 +4,10 @@
 mat = load yetilab.matrix.matrix;
 vec = load yetilab.vector.vector;
 
+{ compare, compareUsing } = load yetilab.test.test;
+
+compareMatrices = compareUsing mat.equal;
+
 time f =
    (start = System#currentTimeMillis();
     result = f ();
@@ -17,10 +21,10 @@
     makeSparse () = 
        (print "Making \(sparsity * 100)% sparse version (as dense matrix)...";
         t = time \(mat.thresholded sparsity m);
-        println "(Reported density: \(mat.density t))";
+        println "Reported density: \(mat.density t) (non-zero values: \(mat.nonZeroValues t))";
         print "Converting to sparse matrix...";
         s = time \(mat.toSparse t);
-        println "(Reported density: \(mat.density s))";
+        println "Reported density: \(mat.density s) (non-zero values: \(mat.nonZeroValues s))";
         s);
     s = makeSparse ();
     println "Making types:";
@@ -38,58 +42,48 @@
 println "\nR * M multiplies:\n";
 
 sz = 2000;
+sparsity = 0.98;
 
-{ cmd, rmd, cms, rms } = makeMatrices sz 0.98;
+{ cmd, rmd, cms, rms } = makeMatrices sz sparsity;
 
 row = mat.newRowVector (vec.fromList (map \(Math#random()) [1..sz]));
 col = mat.newColumnVector (vec.fromList (map \(Math#random()) [1..sz]));
 
 print "R * CMD... ";
-\() (time \(mat.product row cmd));
+a = (time \(mat.product row cmd));
 
 print "R * RMD... ";
-\() (time \(mat.product row rmd));
+b = (time \(mat.product row rmd));
 
 print "R * CMS... ";
-\() (time \(mat.product row cms));
+c = (time \(mat.product row cms));
 
 print "R * RMS... ";
-\() (time \(mat.product row rms));
+d = (time \(mat.product row rms));
+
+println "\nChecking results: \(compareMatrices a b) \(compareMatrices c d)";
 
 println "\nM * C multiplies:\n";
 
 print "CMD * C... ";
-\() (time \(mat.product cmd col));
+a = (time \(mat.product cmd col));
 
 print "RMD * C... ";
-\() (time \(mat.product rmd col));
+b = (time \(mat.product rmd col));
 
 print "CMS * C... ";
-\() (time \(mat.product cms col));
+c = (time \(mat.product cms col));
 
 print "RMS * C... ";
-\() (time \(mat.product rms col));
+d = (time \(mat.product rms col));
 
+println "\nChecking results: \(compareMatrices a b) \(compareMatrices c d)";
 
 println "\nM * M multiplies:\n";
 
 sz = 500;
 
-{ cmd, rmd, cms, rms } = makeMatrices sz 0.98;
-
-print "CMD * CMD... ";
-\() (time \(mat.product cmd cmd));
-
-print "CMD * RMD... ";
-\() (time \(mat.product cmd rmd));
-
-print "RMD * CMD... ";
-\() (time \(mat.product rmd cmd));
-
-print "RMD * RMD... ";
-\() (time \(mat.product rmd rmd));
-
-println "";
+{ cmd, rmd, cms, rms } = makeMatrices sz sparsity;
 
 print "CMS * CMD... ";
 \() (time \(mat.product cms cmd));
@@ -131,4 +125,18 @@
 print "RMS * RMS... ";
 \() (time \(mat.product rms rms));
 
+println "";
+
+print "CMD * CMD... ";
+\() (time \(mat.product cmd cmd));
+
+print "CMD * RMD... ";
+\() (time \(mat.product cmd rmd));
+
+print "RMD * CMD... ";
+\() (time \(mat.product rmd cmd));
+
+print "RMD * RMD... ";
+\() (time \(mat.product rmd rmd));
+
 ();
--- a/yetilab/matrix/test/test_matrix.yeti	Mon May 20 18:11:44 2013 +0100
+++ b/yetilab/matrix/test/test_matrix.yeti	Mon May 20 22:17:19 2013 +0100
@@ -354,6 +354,12 @@
         compare (mat.density (newMatrix (ColumnMajor ()) [[0,0,0],[0,0,0]])) 0
 ),
 
+"nonZeroValues-\(name)": \(
+    compare (mat.nonZeroValues (newMatrix (ColumnMajor ()) [[1,2,0],[0,5,0]])) 3 and
+        compare (mat.nonZeroValues (newMatrix (ColumnMajor ()) [[1,2,3],[4,5,6]])) 6 and
+        compare (mat.nonZeroValues (newMatrix (ColumnMajor ()) [[0,0,0],[0,0,0]])) 0
+),
+
 "toSparse-\(name)": \(
     m = newMatrix (ColumnMajor ()) [[1,2,0],[-1,-4,6],[0,0,3]];
     compareMatrices (mat.toSparse m) m and