diff yetilab/matrix/matrix.yeti @ 254:5eb57c649de0 sparse

Using hashes is simpler, but turns out to be mostly no faster and sometimes much slower. Not one to merge back.
author Chris Cannam
date Tue, 21 May 2013 17:40:33 +0100
parents efdb1aee9d21
children
line wrap: on
line diff
--- a/yetilab/matrix/matrix.yeti	Tue May 21 16:00:21 2013 +0100
+++ b/yetilab/matrix/matrix.yeti	Tue May 21 17:40:33 2013 +0100
@@ -31,32 +31,28 @@
             rows = if major > 0 then vec.length c[0] else 0 fi,
             columns = major, 
         };
-    SparseCSR { values, indices, pointers, extent }:
-        {
-            rows = (length pointers) - 1,
-            columns = extent
-        };
-    SparseCSC { values, indices, pointers, extent }:
-        {
-            rows = extent,
-            columns = (length pointers) - 1
-        };
+    SparseRows { size, data }: size;
+    SparseCols { size, data }: size;
     esac;
 
 width m = (size m).columns;
 height m = (size m).rows;
 
 nonZeroValues m =
-   (nz d =
+   (nzDense d = // count the elements of every dense vector that are not equal to 0
         sum
            (map do v:
                 sum (map do n: if n == 0 then 0 else 1 fi done (vec.list v))
                 done d);
+    nzSparse d = // count the keys held in each inner hash of d.data
+        sum (map do k:
+                 length (keys d.data[k])
+                 done (keys d.data));
     case m of 
-    DenseRows d: nz d;
-    DenseCols d: nz d;
-    SparseCSR d: vec.length d.values;
-    SparseCSC d: vec.length d.values;
+    DenseRows d: nzDense d;
+    DenseCols d: nzDense d;
+    SparseRows d: nzSparse d; // NOTE(review): counts stored entries, so an explicitly stored 0 is included
+    SparseCols d: nzSparse d;
     esac);
 
 density m =
@@ -64,60 +60,36 @@
     cells = rows * columns;
     (nonZeroValues m) / cells);
 
-sparseSlice n d =
-   (start = d.pointers[n];
-    end = d.pointers[n+1];
-    { 
-        values = vec.slice d.values start end,
-        indices = slice d.indices start end,
-    });
+fromSlice n m { data } = // value stored under outer key n, inner key m; 0 when no such entry exists
+    if n in data and m in data[n] then data[n][m] else 0 fi;
 
-nonEmptySlices d =
-   (ne = array [];
-    for [0..length d.pointers - 2] do i:
-        if d.pointers[i] != d.pointers[i+1] then
-            push ne i
-        fi
-    done;
-    ne);
-
-fromSlice n m d =
-   (slice = sparseSlice n d;
-    var v = 0;
-    for [0..length slice.indices - 1] do i:
-        if slice.indices[i] == m then
-            v := vec.at i slice.values;
-        fi
-    done;
-    v);
-
-filledSlice n d =
-   (slice = sparseSlice n d;
-    dslice = new double[d.extent];
-    for [0..length slice.indices - 1] do i:
-        dslice[slice.indices[i]] := vec.at i slice.values;
-    done;
-    vec.vector dslice);
+filledSlice sz n { data } = // expand the entries stored under outer key n into a vector of length sz
+   (slice = new double[sz]; // NOTE(review): presumably zero-initialised, so absent keys read as 0 — confirm
+    if n in data then
+        h = data[n];
+        for (keys h) do k: slice[k] := h[k] done; // copy each stored value to the position given by its inner key
+    fi;
+    vec.vector slice);
 
 getAt row col m =
     case m of
     DenseRows rows: r = rows[row]; vec.at col r;
     DenseCols cols: c = cols[col]; vec.at row c;
-    SparseCSR data: fromSlice row col data;
-    SparseCSC data: fromSlice col row data;
+    SparseRows data: fromSlice row col data; // row is passed as the outer key, col as the inner key
+    SparseCols data: fromSlice col row data; // col is passed as the outer key, row as the inner key
     esac;
 
 getColumn j m =
     case m of
     DenseCols cols: cols[j];
-    SparseCSC data: filledSlice j data;
+    SparseCols data: filledSlice data.size.rows j data; // result length is the row count
     _: vec.fromList (map do i: getAt i j m done [0..height m - 1]);
     esac;
 
 getRow i m =
     case m of
     DenseRows rows: rows[i];
-    SparseCSR data: filledSlice i data; 
+    SparseRows data: filledSlice data.size.columns i data; // result length is the column count
     _: vec.fromList (map do j: getAt i j m done [0..width m - 1]);
     esac;
 
@@ -125,16 +97,16 @@
     case m of
     DenseRows _: true;
     DenseCols _: false;
-    SparseCSR _: true;
-    SparseCSC _: false;
+    SparseRows _: true;
+    SparseCols _: false;
     esac;
 
 isSparse? m =
     case m of
     DenseRows _: false;
     DenseCols _: false;
-    SparseCSR _: true;
-    SparseCSC _: true;
+    SparseRows _: true; // the two Sparse* variants are the only ones reported as sparse
+    SparseCols _: true;
     esac;
 
 newColumnMajorStorage { rows, columns } = 
@@ -167,58 +139,33 @@
     fi;
 
 enumerateSparse m =
-   (enumerate { values, indices, pointers } =
+   (enumerate { size, data } = // one { i, j, v } record per stored entry: i is the outer key, j the inner key
         concat
            (map do i:
-                start = pointers[i];
-                end = pointers[i+1];
-                map2 do j v: { i, j, v } done 
-                    (slice indices start end)
-                    (vec.list (vec.slice values start end))
-                done [0..length pointers - 2]);
+                map do j:
+                    { i, j, v = data[i][j] }
+                    done (keys data[i])
+                done (keys data));
     case m of
-    SparseCSC d: 
+    SparseCols d: // swap i and j in each record produced by enumerate
         map do { i, j, v }: { i = j, j = i, v } done (enumerate d);
-    SparseCSR d:
+    SparseRows d:
         enumerate d;
      _: [];
     esac);
 
-makeSparse type size data =
-   (isRow = case type of RowMajor (): true; ColumnMajor (): false esac;
-    ordered = 
-        sortBy do a b:
-            if a.maj == b.maj then a.min < b.min else a.maj < b.maj fi
-        done
-           (map
-                if isRow then
-                    do { i, j, v }: { maj = i, min = j, v } done;
-                else
-                    do { i, j, v }: { maj = j, min = i, v } done;
-                fi
-                data);
-    tagger = if isRow then SparseCSR else SparseCSC fi;
-    majorSize = if isRow then size.rows else size.columns fi;
-    minorSize = if isRow then size.columns else size.rows fi;
-    pointers = array [0];
-    setArrayCapacity pointers (size.rows + 1);
-    fillPointers n i data =
-        if n < majorSize then
-            case data of
-            d::rest:
-               (for [n..d-1] \(push pointers i);
-                fillPointers d (i+1) rest);
-             _:
-                for [n..majorSize-1] \(push pointers i);
-            esac;
+makeSparse type size entries = // build a SparseRows/SparseCols value (hash of hashes) from { i, j, v } entries
+   (isRow = case type of RowMajor (): true; ColumnMajor (): false; esac;
+    data = [:];
+    preprocess = // for ColumnMajor, swap i and j so the outer key becomes j
+        if isRow then id
+        else do { i, j, v }: { i = j, j = i, v } done
         fi;
-    fillPointers 0 0 (map (.maj) ordered);
-    tagger {
-        values = vec.fromList (map (.v) ordered),
-        indices = array (map (.min) ordered),
-        pointers,
-        extent = minorSize,
-    });
+    for (map preprocess entries) do { i, j, v }:
+        if not (i in data) then data[i] := [:] fi; // create the inner hash the first time an outer key is seen
+        data[i][j] := v; // a later entry for the same cell overwrites an earlier one; v == 0 is stored as given
+    done;
+    if isRow then SparseRows else SparseCols fi { size, data });
 
 toSparse m =
     if isSparse? m then m
@@ -259,17 +206,31 @@
     case m of
     DenseRows d: DenseCols d;
     DenseCols d: DenseRows d;
-    SparseCSR d: SparseCSC d;
-    SparseCSC d: SparseCSR d;
+    SparseRows { data, size }: SparseCols
+        { data, size = { rows = size.columns, columns = size.rows } }; 
+    SparseCols { data, size }: SparseRows
+        { data, size = { rows = size.columns, columns = size.rows } }; 
     esac;
 
+sparseFlipped m = // rebuild a sparse matrix under the other Sparse* tag with outer and inner keys exchanged
+   ({ tagger, data } = 
+        case m of
+        SparseCols { data }: { tagger = SparseRows, data };
+        SparseRows { data }: { tagger = SparseCols, data };
+        _: failWith "sparseFlipped called for non-sparse matrix";
+        esac;
+    data' = [:];
+    for (keys data) do i:
+        for (keys data[i]) do j:
+            if not (j in data') then data'[j] := [:] fi; // old inner key j becomes the new outer key
+            data'[j][i] := data[i][j];
+        done
+    done;
+    tagger { size = size m, data = data' }); // size is passed through unchanged; only the tag and key nesting differ
+
 flipped m =
     if isSparse? m then
-        if isRowMajor? m then
-            makeSparse (ColumnMajor ()) (size m) (enumerateSparse m)
-        else
-            makeSparse (RowMajor ()) (size m) (enumerateSparse m)
-        fi
+        sparseFlipped m
     else
         if isRowMajor? m then
             generate do row col: getAt row col m done (size m);
@@ -287,22 +248,29 @@
     if not isRowMajor? m then m else flipped m fi;
 
 equal'' comparator vecComparator m1 m2 =
-    // Prerequisite: m1 and m2 have same sparse-p and storage order
+    // Prerequisite: m1 and m2 have same size, sparse-p, and storage order
    (compareVecLists vv1 vv2 = all id (map2 vecComparator vv1 vv2);
     compareSparse d1 d2 =
-        d1.extent == d2.extent and
-        vecComparator d1.values d2.values and
-        d1.indices == d2.indices and
-        d1.pointers == d2.pointers;
+       (data1 = d1.data; // hash key order is not guaranteed, so key lists are sorted before being compared
+        data2 = d2.data;
+        sort (keys data1) == sort (keys data2) and // same set of outer keys
+            all id
+               (map do i: 
+                    sort (keys data1[i]) == sort (keys data2[i]) and // same set of inner keys; a stored 0 vs. no entry still differs
+                    all id
+                       (map do j:
+                            comparator data1[i][j] data2[i][j]
+                            done (keys data1[i]))
+                    done (keys data1)));
     case m1 of
     DenseRows d1:
         case m2 of DenseRows d2: compareVecLists d1 d2; _: false; esac;
     DenseCols d1:
         case m2 of DenseCols d2: compareVecLists d1 d2; _: false; esac;
-    SparseCSR d1:
-        case m2 of SparseCSR d2: compareSparse d1 d2; _: false; esac;
-    SparseCSC d1:
-        case m2 of SparseCSC d2: compareSparse d1 d2; _: false; esac;
+    SparseRows d1:
+        case m2 of SparseRows d2: compareSparse d1 d2; _: false; esac;
+    SparseCols d1:
+        case m2 of SparseCols d2: compareSparse d1 d2; _: false; esac;
     esac);
 
 equal' comparator vecComparator m1 m2 =
@@ -393,27 +361,21 @@
 
 sparseProduct size m1 m2 =
     case m2 of
-    SparseCSC d:
+    SparseCols d: // m2 is column-major: build each output column j' from the stored entries of m1
        (e = enumerateSparse m1;
-        entries =
-           (map do j':
-                cs = sparseSlice j' d;
-                hin = mapIntoHash
-                   (at cs.indices) ((flip vec.at) cs.values)
-                   [0..length cs.indices - 1];
-                hout = [:];
-                for e do { v, i, j }:
-                    if j in hin then
-                        p = v * hin[j];
-                        hout[i] := p + (if i in hout then hout[i] else 0 fi);
-                    fi
-                done;
-                map do i:
-                    { i, j = j', v = hout[i] }
-                done (keys hout);
-            done (nonEmptySlices d));
-        makeSparse (ColumnMajor ()) size (concat entries));
-    SparseCSR _:
+        out = [:];
+        for (keys d.data) do j':
+            h = [:]; // accumulates output column j', keyed by row i
+            for e do { v, i, j }:
+                if j in d.data[j'] then
+                    if not (i in h) then h[i] := 0 fi;
+                    h[i] := h[i] + v * d.data[j'][j];
+                fi;
+            done;
+            if not (empty? (keys h)) then out[j'] := h fi; // skip columns that received no products, so no empty inner hash is stored
+        done;
+        SparseCols { size, data = out });
+    SparseRows _: // retry with m2 passed through flipped
         sparseProduct size m1 (flipped m2);
      _: failWith "sparseProduct called for non-sparse matrices";
     esac;