Mercurial > hg > may
changeset 238:0c86d9284f20 sparse
Implement sparse matrix construction, add tests for sparse matrices (currently failing)
author | Chris Cannam |
---|---|
date | Mon, 20 May 2013 14:18:14 +0100 |
parents | 601dbfcf949d |
children | 741784624bb6 |
files | yetilab/matrix/matrix.yeti yetilab/matrix/test/test_matrix.yeti |
diffstat | 2 files changed, 111 insertions(+), 24 deletions(-) [+] |
line wrap: on
line diff
--- a/yetilab/matrix/matrix.yeti Sun May 19 22:21:47 2013 +0100 +++ b/yetilab/matrix/matrix.yeti Mon May 20 14:18:14 2013 +0100 @@ -159,14 +159,71 @@ _: []; esac); -makeSparse type data = -//!!! not implemented - case type of - RowMajor (): - SparseCSR { values = vec.zeros 0, indices = array [], pointers = array [], extent = 0 }; - ColMajor (): - SparseCSC { values = vec.zeros 0, indices = array [], pointers = array [], extent = 0 }; - esac; +makeSparse type size data = + (isRow = case type of RowMajor (): true; ColMajor (): false esac; + ordered = + sortBy do a b: + if a.maj == b.maj then a.min < b.min else a.maj < b.maj fi + done + (map + if isRow then + do { i, j, v }: { maj = i, min = j, v } done; + else + do { i, j, v }: { maj = j, min = i, v } done; + fi + data); + tagger = if isRow then SparseCSR else SparseCSC fi; + majorSize = if isRow then size.rows else size.columns fi; + minorSize = if isRow then size.columns else size.rows fi; + majorPointers acc nn n i data = + if n < nn then + case data of + d::rest: + majorPointers (acc ++ (map \(i) [n..d-1])) nn d (i+1) rest; + _: + majorPointers (acc ++ [i]) nn (n+1) i []; + esac; + else + acc + fi; + tagger { + values = vec.fromList (map (.v) ordered), + indices = array (map (.min) ordered), + pointers = array (majorPointers [] majorSize 0 0 (map (.maj) ordered)), + extent = minorSize, + }); + +toSparse threshold m = + if isSparse? m then m + else + { rows, columns } = size m; + enumerate threshold m ii jj = + case ii of + i::irest: + case jj of + j::rest: + v = getAt i j m; + if abs v > threshold then + { i, j, v } :. \(enumerate threshold m ii rest) + else enumerate threshold m ii rest + fi; + _: enumerate threshold m irest [0..columns-1]; + esac; + _: []; + esac; + makeSparse + if isRowMajor? m then RowMajor () else ColMajor () fi + (size m) + (enumerate threshold m [0..rows-1] [0..columns-1]); + fi; + +toDense m = + if not (isSparse? m) then m + elif isRowMajor? m then + DenseRows (array (map do row: getRow row m done [0..height m - 1])); + else + DenseCols (array (map do col: getColumn col m done [0..width m - 1])); + fi; constMatrix n = generate do row col: n done; randomMatrix = generate do row col: Math#random() done; @@ -183,9 +240,9 @@ flipped m = if isSparse? m then if isRowMajor? m then - makeSparse (ColMajor ()) (enumerateSparse m) + makeSparse (ColMajor ()) (size m) (enumerateSparse m) else - makeSparse (RowMajor ()) (enumerateSparse m) + makeSparse (RowMajor ()) (size m) (enumerateSparse m) fi else if isRowMajor? m then @@ -203,30 +260,52 @@ toColumnMajor m = if not isRowMajor? m then m else flipped m fi; -equal' vecComparator m1 m2 = +equal'' comparator vecComparator m1 m2 = + // Prerequisite: m1 and m2 have same sparse-p and storage order + (compareLists l1 l2 = all id (map2 comparator l1 l2); + compareVecLists vv1 vv2 = all id (map2 vecComparator vv1 vv2); + compareSparse d1 d2 = + d1.extent == d2.extent and + vecComparator d1.values d2.values and + compareLists d1.indices d2.indices and + compareLists d1.pointers d2.pointers; + case m1 of + DenseRows d1: + case m2 of DenseRows d2: compareVecLists d1 d2; _: false; esac; + DenseCols d1: + case m2 of DenseCols d2: compareVecLists d1 d2; _: false; esac; + SparseCSR d1: + case m2 of SparseCSR d2: compareSparse d1 d2; _: false; esac; + SparseCSC d1: + case m2 of SparseCSC d2: compareSparse d1 d2; _: false; esac; + esac); + +equal' comparator vecComparator m1 m2 = if size m1 != size m2 then false elif isRowMajor? m1 != isRowMajor? m2 then - equal' vecComparator (flipped m1) m2; + equal' comparator vecComparator (flipped m1) m2; + elif isSparse? m1 != isSparse? m2 then + if isSparse? m1 then + equal' comparator vecComparator m1 (toSparse 0 m2) + else + equal' comparator vecComparator (toSparse 0 m1) m2 + fi else - compare d1 d2 = all id (map2 vecComparator d1 d2); - case m1 of - DenseRows d1: case m2 of DenseRows d2: compare d1 d2; _: false; esac; - DenseCols d1: case m2 of DenseCols d2: compare d1 d2; _: false; esac; - esac + equal'' comparator vecComparator m1 m2 fi; // Compare matrices using the given comparator for individual cells. // Note that matrices with different storage order but the same // contents are equal, although comparing them is slow. equalUnder comparator = - equal' (vec.equalUnder comparator); + equal' comparator (vec.equalUnder comparator); equal = - equal' vec.equal; + equal' (==) vec.equal; newMatrix type data = //!!! NB does not copy data - (tagger = case type of RowMajor (): RowM; ColumnMajor (): DenseCols esac; + (tagger = case type of RowMajor (): DenseRows; ColumnMajor (): DenseCols esac; if empty? data or vec.empty? (head data) then zeroSizeMatrix () else tagger (array data) @@ -383,6 +462,8 @@ flipped, toRowMajor, toColumnMajor, + toSparse, + toDense, scaled, resizedTo, asRows, @@ -422,6 +503,8 @@ flipped is matrix -> matrix, toRowMajor is matrix -> matrix, toColumnMajor is matrix -> matrix, + toSparse is number -> matrix -> matrix, + toDense is matrix -> matrix, scaled is number -> matrix -> matrix, resizedTo is { .rows is number, .columns is number } -> matrix -> matrix, asRows is matrix -> list<vector>,
--- a/yetilab/matrix/test/test_matrix.yeti Sun May 19 22:21:47 2013 +0100 +++ b/yetilab/matrix/test/test_matrix.yeti Mon May 20 14:18:14 2013 +0100 @@ -373,12 +373,16 @@ ]); -colhash = makeTests "column-major" id; -rowhash = makeTests "row-major" mat.flipped; +colhash = makeTests "column-dense" id; +rowhash = makeTests "row-dense" mat.flipped; +sparsecolhash = makeTests "column-sparse" (mat.toSparse 0); +sparserowhash = makeTests "row-sparse" ((mat.toSparse 0) . (mat.flipped)); +//sparserowhash2 = makeTests "row-sparse2" ((mat.flipped) . (mat.toSparse 0)); all = [:]; -for (keys colhash) do k: all[k] := colhash[k] done; -for (keys rowhash) do k: all[k] := rowhash[k] done; +for [ colhash, rowhash, sparsecolhash, sparserowhash /*, sparserowhash2 */ ] do h: + for (keys h) do k: all[k] := h[k] done; +done; all is hash<string, () -> boolean>;