changeset 253:95f2565f9471 sparse

Merge from default branch
author Chris Cannam
date Tue, 21 May 2013 16:00:21 +0100
parents 1ea5bf6e76b6 (current diff) efdb1aee9d21 (diff)
children 5eb57c649de0
files
diffstat 4 files changed, 133 insertions(+), 26 deletions(-) [+]
line wrap: on
line diff
--- a/yetilab/matrix/matrix.yeti	Mon May 20 22:17:19 2013 +0100
+++ b/yetilab/matrix/matrix.yeti	Tue May 21 16:00:21 2013 +0100
@@ -72,6 +72,15 @@
         indices = slice d.indices start end,
     });
 
+nonEmptySlices d = // array of every i where d.pointers[i] differs from d.pointers[i+1]
+   (ne = array [];
+    for [0..length d.pointers - 2] do i: // each i that has a successor in d.pointers
+        if d.pointers[i] != d.pointers[i+1] then // presumably: slice i holds a stored value -- confirm against sparseSlice
+            push ne i
+        fi
+    done;
+    ne); // indices were pushed in ascending order
+
 fromSlice n m d =
    (slice = sparseSlice n d;
     var v = 0;
@@ -191,21 +200,23 @@
     tagger = if isRow then SparseCSR else SparseCSC fi;
     majorSize = if isRow then size.rows else size.columns fi;
     minorSize = if isRow then size.columns else size.rows fi;
-    majorPointers acc nn n i data =
-        if n < nn then 
-            case data of 
+    pointers = array [0]; // fillPointers below appends one entry per major index to this leading 0
+    setArrayCapacity pointers (majorSize + 1); // final length is majorSize + 1 in either orientation, not size.rows + 1
+    fillPointers n i data =
+        if n < majorSize then
+            case data of
             d::rest:
-                majorPointers (acc ++ (map \(i) [n..d-1])) nn d (i+1) rest;
-             _: 
-                majorPointers (acc ++ [i]) nn (n+1) i [];
+               (for [n..d-1] \(push pointers i);
+                fillPointers d (i+1) rest);
+             _:
+                for [n..majorSize-1] \(push pointers i);
             esac;
-        else
-            acc
         fi;
+    fillPointers 0 0 (map (.maj) ordered);
     tagger {
         values = vec.fromList (map (.v) ordered),
         indices = array (map (.min) ordered),
-        pointers = array (majorPointers [0] majorSize 0 0 (map (.maj) ordered)),
+        pointers,
         extent = minorSize,
     });
 
@@ -380,6 +391,33 @@
     done;
     DenseRows (array (map vec.vector (list data))));
 
+sparseProduct size m1 m2 = // product of m1 and m2, assembled through makeSparse in column-major order
+    case m2 of
+    SparseCSC d:
+       (e = enumerateSparse m1; // consumed below as { v, i, j } records
+        entries =
+           (map do j': // j' runs over the slices of m2 reported by nonEmptySlices
+                cs = sparseSlice j' d;
+                hin = mapIntoHash // key: cs.indices[k], value: the element of cs.values at the same k
+                   (at cs.indices) ((flip vec.at) cs.values)
+                   [0..length cs.indices - 1];
+                hout = [:]; // running sums of products for this j', keyed by i
+                for e do { v, i, j }: // NOTE(review): all of m1's entries are scanned once per visited slice
+                    if j in hin then
+                        p = v * hin[j];
+                        hout[i] := p + (if i in hout then hout[i] else 0 fi);
+                    fi
+                done;
+                map do i:
+                    { i, j = j', v = hout[i] }
+                done (keys hout);
+            done (nonEmptySlices d));
+        makeSparse (ColumnMajor ()) size (concat entries));
+    SparseCSR _:
+        sparseProduct size m1 (flipped m2); // retry on flipped m2; presumably that yields the SparseCSC form -- confirm
+     _: failWith "sparseProduct called for non-sparse matrices";
+    esac;
+
 denseProduct size m1 m2 =
    (data = array (map \(new double[size.rows]) [1..size.columns]);
     for [0..size.rows - 1] do i:
@@ -396,7 +434,11 @@
     else 
         size = { rows = (size m1).rows, columns = (size m2).columns };
         if isSparse? m1 then
-            sparseProductLeft size m1 m2
+            if isSparse? m2 then
+                sparseProduct size m1 m2
+            else
+                sparseProductLeft size m1 m2
+            fi
         elif isSparse? m2 then
             sparseProductRight size m1 m2
         else
--- a/yetilab/matrix/test/speedtest.yeti	Mon May 20 22:17:19 2013 +0100
+++ b/yetilab/matrix/test/speedtest.yeti	Tue May 21 16:00:21 2013 +0100
@@ -79,6 +79,10 @@
 
 println "\nChecking results: \(compareMatrices a b) \(compareMatrices c d)";
 
+reportOn m = // prints an indent of spaces, then mat.isSparse? and mat.density of m on one line
+   (print "                                 ";
+    println "isSparse: \(mat.isSparse? m), density \(mat.density m)");
+
 println "\nM * M multiplies:\n";
 
 sz = 500;
@@ -86,57 +90,93 @@
 { cmd, rmd, cms, rms } = makeMatrices sz sparsity;
 
 print "CMS * CMD... ";
-\() (time \(mat.product cms cmd));
+reportOn (time \(mat.product cms cmd));
 
 print "CMS * RMD... ";
-\() (time \(mat.product cms rmd));
+reportOn (time \(mat.product cms rmd));
 
 print "RMS * CMD... ";
-\() (time \(mat.product rms cmd));
+reportOn (time \(mat.product rms cmd));
 
 print "RMS * RMD... ";
-\() (time \(mat.product rms rmd));
+reportOn (time \(mat.product rms rmd));
 
 println "";
 
 print "CMD * CMS... ";
-\() (time \(mat.product cmd cms));
+reportOn (time \(mat.product cmd cms));
 
 print "CMD * RMS... ";
-\() (time \(mat.product cmd rms));
+reportOn (time \(mat.product cmd rms));
 
 print "RMD * CMS... ";
-\() (time \(mat.product rmd cms));
+reportOn (time \(mat.product rmd cms));
 
 print "RMD * RMS... ";
-\() (time \(mat.product rmd rms));
+reportOn (time \(mat.product rmd rms));
 
 println "";
 
 print "CMS * CMS... ";
-\() (time \(mat.product cms cms));
+reportOn (time \(mat.product cms cms));
 
 print "CMS * RMS... ";
-\() (time \(mat.product cms rms));
+reportOn (time \(mat.product cms rms));
 
 print "RMS * CMS... ";
-\() (time \(mat.product rms cms));
+reportOn (time \(mat.product rms cms));
 
 print "RMS * RMS... ";
-\() (time \(mat.product rms rms));
+reportOn (time \(mat.product rms rms));
 
 println "";
 
 print "CMD * CMD... ";
-\() (time \(mat.product cmd cmd));
+reportOn (time \(mat.product cmd cmd));
 
 print "CMD * RMD... ";
-\() (time \(mat.product cmd rmd));
+reportOn (time \(mat.product cmd rmd));
 
 print "RMD * CMD... ";
-\() (time \(mat.product rmd cmd));
+reportOn (time \(mat.product rmd cmd));
 
 print "RMD * RMD... ";
-\() (time \(mat.product rmd rmd));
+reportOn (time \(mat.product rmd rmd));
+
+println "\nLarge sparse M * M multiplies:\n";
+
+sz = 5000000; // rebinds sz: the matrix built below has sz rows and sz columns
+nnz = 10000; // number of random { i, j, v } records generated below
+
+print "Calculating \(nnz) non-zero entry records...";
+entries = time \(e = map \({ i = int (Math#random() * sz), 
+                             j = int (Math#random() * sz),
+                             v = Math#random() }) [1..nnz];
+                 \() (length e); // make sure list non-lazy for timing purposes
+                 e);
+
+print "Making \(sz) * \(sz) random matrix with \(nnz) entries...";
+rms = time \(mat.newSparseMatrix (RowMajor ()) { rows = sz, columns = sz } // rebinds rms for the products below
+             entries);
+println "Reported density: \(mat.density rms) (non-zero values: \(mat.nonZeroValues rms))";
+
+print "Making col-major copy...";
+cms = time \(mat.toColumnMajor rms); // rebinds cms to a column-major copy of the new rms
+println "Reported density: \(mat.density cms) (non-zero values: \(mat.nonZeroValues cms))";
+
+println "";
+
+print "CMS * CMS... "; // the same four sparse * sparse orderings as in the smaller run above
+reportOn (time \(mat.product cms cms));
+
+print "CMS * RMS... ";
+reportOn (time \(mat.product cms rms));
+
+print "RMS * CMS... ";
+reportOn (time \(mat.product rms cms));
+
+print "RMS * RMS... ";
+reportOn (time \(mat.product rms rms));
+
 
 ();
--- a/yetilab/matrix/test/test_matrix.yeti	Mon May 20 22:17:19 2013 +0100
+++ b/yetilab/matrix/test/test_matrix.yeti	Tue May 21 16:00:21 2013 +0100
@@ -249,6 +249,30 @@
            (newMatrix (ColumnMajor ()) [[58,139],[64,154]])
 ),
 
+"sparseProduct-\(name)": \( // sparse * sparse must stay sparse and agree with the dense and mixed products
+    s = mat.newSparseMatrix (ColumnMajor ()) { rows = 2, columns = 3 } [
+        { i = 0, j = 0, v = 1 },
+        { i = 0, j = 2, v = 2 },
+        { i = 1, j = 1, v = 4 },
+    ];
+    t = mat.newSparseMatrix (ColumnMajor ()) { rows = 3, columns = 2 } [
+        { i = 0, j = 1, v = 7 },
+        { i = 1, j = 0, v = 5 },
+        { i = 2, j = 0, v = 6 },
+    ];
+    prod = mat.product s t; // both operands sparse
+    mat.isSparse? prod and
+        compareMatrices prod (mat.product (mat.toDense s) t) and // dense * sparse
+        compareMatrices prod (mat.product (mat.toDense s) (mat.toDense t)) and // dense * dense
+        compareMatrices prod (mat.product s (mat.toDense t)) and // sparse * dense
+        compareMatrices prod 
+           (mat.newSparseMatrix (RowMajor ()) { rows = 2, columns = 2 } [ // expected result, given in row-major form
+               { i = 0, j = 0, v = 12 }, // 1*0 + 0*5 + 2*6
+               { i = 0, j = 1, v = 7 }, // 1*7
+               { i = 1, j = 0, v = 20 }, // 4*5
+            ])
+),
+
 "productFail-\(name)": \(
     try
       \() (mat.product (constMatrix 2 { rows = 4, columns = 2 })
--- a/yetilab/vector/vector.yeti	Mon May 20 22:17:19 2013 +0100
+++ b/yetilab/vector/vector.yeti	Tue May 21 16:00:21 2013 +0100
@@ -40,6 +40,7 @@
 empty?' =
     empty? . list';
 
+//!!! argument order here is (index, vector), reversed relative to std at, which takes (container, index) -- fix
 at' n v is number -> ~double[] -> number =
     v[n];