view yetilab/matrix/matrix.yeti @ 238:0c86d9284f20 sparse

Implement sparse matrix construction, add tests for sparse matrices (currently failing)
author Chris Cannam
date Mon, 20 May 2013 14:18:14 +0100
parents 601dbfcf949d
children 741784624bb6
line wrap: on
line source

module yetilab.matrix.matrix;

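// A dense matrix is an array of vectors (one per row or per column). A
// sparse matrix is stored in compressed sparse row (SparseCSR) or
// compressed sparse column (SparseCSC) form.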

// A matrix can be stored in either column-major (the default) or
// row-major format. Storage order is an efficiency concern only:
// every API function operating on matrix objects will return the same
// result regardless of storage order. (The transposed function just
// switches the row/column order without moving the elements.)

//!!! check that we are not unnecessarily copying in the transform functions

vec = load yetilab.vector.vector;
bf = load yetilab.vector.blockfuncs;

load yetilab.vector.vectortype;
load yetilab.matrix.matrixtype;

// Dimensions of a matrix as { rows, columns }.
// Dense storage: the outer array length gives the major dimension and
// the length of its first vector the minor one (0 if there are none).
// Sparse storage: the pointer array has one entry more than there are
// major slices, and the minor dimension is kept in `extent`.
size m =
    case m of
    DenseRows rowVectors:
       (count = length rowVectors;
        { rows = count,
          columns = if count > 0 then vec.length rowVectors[0] else 0 fi });
    DenseCols colVectors:
       (count = length colVectors;
        { rows = if count > 0 then vec.length colVectors[0] else 0 fi,
          columns = count });
    SparseCSR d:
        { rows = length d.pointers - 1, columns = d.extent };
    SparseCSC d:
        { rows = d.extent, columns = length d.pointers - 1 };
    esac;

// Number of columns in m.
width m = (dims = size m; dims.columns);
// Number of rows in m.
height m = (dims = size m; dims.rows);

// Extract the stored entries of major slice n (row n of CSR data, or
// column n of CSC data) from sparse storage d. Slice n occupies
// positions pointers[n] up to, but not including, pointers[n+1] of both
// the values vector and the indices array.
// Returns { values, indices } for just that slice.
sparseSlice n d =
   (startPos = d.pointers[n];
    // Not named `end`: that word is a Yeti keyword (class ... end).
    endPos = d.pointers[n+1];
    {
        values = vec.slice d.values startPos endPos,
        indices = slice d.indices startPos endPos,
    });

// Value at minor index m within major slice n of sparse data d, or 0
// when no entry is stored there. If several stored entries have index m,
// the last one wins.
fromSlice n m d =
   (entries = sparseSlice n d;
    fold do found k:
        if entries.indices[k] == m then vec.at k entries.values else found fi
    done 0 [0 .. length entries.indices - 1]);

// Expand major slice n of sparse data d into a dense vector of length
// d.extent. Positions with no stored entry are zero.
filledSlice n d =
   (entries = sparseSlice n d;
    dense = new double[d.extent];
    // Scatter eagerly with a loop. Building a lazy map2 list whose
    // result is discarded does not reliably perform the writes.
    for [0 .. length entries.indices - 1] do k:
        pos = entries.indices[k];
        dense[pos] := vec.at k entries.values
    done;
    vec.vector dense);

// Element at (row, col) of m, whatever its storage.
getAt row col m =
    case m of
    DenseRows rowArray: vec.at col rowArray[row];
    DenseCols colArray: vec.at row colArray[col];
    SparseCSR sparse: fromSlice row col sparse;
    SparseCSC sparse: fromSlice col row sparse;
    esac;

// Column j of m as a vector. Column-major storage returns its stored
// column directly (sparse columns are expanded); anything else is
// assembled element by element.
getColumn j m =
    case m of
    DenseCols colArray: colArray[j];
    SparseCSC sparse: filledSlice j sparse;
    _: (rowIndices = [0 .. height m - 1];
        vec.fromList (map do r: getAt r j m done rowIndices));
    esac;

// Row i of m as a vector. Row-major storage returns its stored row
// directly (sparse rows are expanded); anything else is assembled
// element by element.
getRow i m =
    case m of
    DenseRows rowArray: rowArray[i];
    SparseCSR sparse: filledSlice i sparse;
    _: (colIndices = [0 .. width m - 1];
        vec.fromList (map do c: getAt i c m done colIndices));
    esac;

// True for row-major storage (DenseRows or SparseCSR).
isRowMajor? m =
    case m of
    DenseRows _: true;
    SparseCSR _: true;
    DenseCols _: false;
    SparseCSC _: false;
    esac;

// True for compressed sparse storage (SparseCSR or SparseCSC).
isSparse? m =
    case m of
    SparseCSR _: true;
    SparseCSC _: true;
    DenseRows _: false;
    DenseCols _: false;
    esac;

// Allocate column-major dense storage: an array of `columns` zero
// vectors, each of length `rows`. When rows < 1 no columns are
// allocated, so the storage is empty in both dimensions.
newColMajorStorage { rows, columns } =
   (columnList =
        if rows < 1 then []
        else map \(vec.zeros rows) [1..columns]
        fi;
    array columnList);

// A dense column-major matrix of the given size filled with zeros.
zeroMatrix { rows, columns } =
   (storage = newColMajorStorage { rows, columns };
    DenseCols storage);

// A dense zero matrix of the given size, using the same storage order
// (row- or column-major) as m. The result is always dense.
zeroMatrixWithTypeOf m { rows, columns } =
    if not (isRowMajor? m) then
        DenseCols (newColMajorStorage { rows, columns })
    else
        // Row-major storage is column-major storage of the transpose.
        DenseRows (newColMajorStorage { rows = columns, columns = rows })
    fi;

// The empty (0 x 0) matrix.
zeroSizeMatrix () = zeroMatrix { columns = 0, rows = 0 };

// Build a dense column-major matrix of the given size whose element at
// (row, col) is f row col. f is called column by column, top to bottom.
// A size with fewer than one row or column gives the empty matrix.
generate f { rows, columns } =
    if rows < 1 or columns < 1 then
        zeroSizeMatrix ()
    else
       (buffers = array (map \(new double[rows]) [1..columns]);
        for [0..columns-1] do c:
            buffer = buffers[c];
            for [0..rows-1] do r:
                buffer[r] := f r c
            done
        done;
        DenseCols (array (map vec.vector buffers)))
    fi;

// List the stored entries of a sparse matrix as { i, j, v } records
// (i = row index, j = column index, v = value), in storage order.
// Dense matrices give the empty list.
//
// The walk carries `ends`, the remaining part of the pointer array.
// Its head is the exclusive end position of the current major slice
// `maj`. Before the value at position n is emitted, every slice ending
// at or before n (including empty slices) is skipped. That is why
// `ends` starts as the pointer array without its first element, which
// is the start position of slice 0 rather than an end position.
enumerateSparse m =
   (enumerate vv ix ends maj n =
        case vv of
        v::rest:
            if n >= head ends then
                enumerate vv ix (tail ends) (maj + 1) n
            else
                { v, i = ix[n], j = maj } :. \(enumerate rest ix ends maj (n + 1))
            fi;
        _: [];
        esac;
    entriesOf d =
        enumerate (vec.list d.values) d.indices (tail (list d.pointers)) 0 0;
    case m of
    SparseCSC d:
        // Major index is the column; stored (minor) index is the row.
        entriesOf d;
    SparseCSR d:
        // Major index is the row, so swap i and j.
        map do { i, j, v }: { i = j, j = i, v } done (entriesOf d);
    _: [];
    esac);

// Build a sparse matrix of the given size from a list of { i, j, v }
// entries (i = row, j = column, v = value). RowMajor () yields
// SparseCSR and ColMajor () yields SparseCSC. Entries may arrive in any
// order; they are sorted by (major, minor) index here.
// NOTE(review): assumes every entry lies within `size` and that no
// (i, j) pair repeats. Neither is checked here, so confirm callers
// guarantee it.
makeSparse type size data =
   (isRow = case type of RowMajor (): true; ColMajor (): false esac;
    ordered =
        sortBy do a b:
            if a.maj == b.maj then a.min < b.min else a.maj < b.maj fi
        done
           (map
                if isRow then
                    do { i, j, v }: { maj = i, min = j, v } done;
                else
                    do { i, j, v }: { maj = j, min = i, v } done;
                fi
                data);
    tagger = if isRow then SparseCSR else SparseCSC fi;
    majorSize = if isRow then size.rows else size.columns fi;
    minorSize = if isRow then size.columns else size.rows fi;
    // The pointer array has majorSize + 1 entries. pointers[k] is the
    // number of entries whose major index is below k, so slice k
    // occupies positions pointers[k] .. pointers[k+1]-1 (the layout
    // sparseSlice reads).
    // pointersFrom k consumed majors lazily yields pointers[k] onwards.
    // `majors` is the sorted list of major indices not yet counted, and
    // `consumed` is how many have been counted so far.
    pointersFrom k consumed majors =
        if k > majorSize then []
        else
            case majors of
            d::rest:
                if d < k then
                    pointersFrom k (consumed + 1) rest
                else
                    consumed :. \(pointersFrom (k + 1) consumed majors)
                fi;
            _:
                consumed :. \(pointersFrom (k + 1) consumed []);
            esac
        fi;
    tagger {
        values = vec.fromList (map (.v) ordered),
        indices = array (map (.min) ordered),
        pointers = array (pointersFrom 0 0 (map (.maj) ordered)),
        extent = minorSize,
    });

// Convert m to sparse storage with the same storage order, keeping only
// cells whose absolute value exceeds `threshold`. A matrix that is
// already sparse is returned unchanged.
toSparse threshold m =
    if isSparse? m then m
    else
        { rows, columns } = size m;
        allColumns = [0..columns-1];
        // Lazily walk the cells in row order, starting at row r with the
        // remaining column indices cc, and keep the ones above threshold.
        cellsFrom r cc =
            case cc of
            c::more:
                v = getAt r c m;
                if abs v > threshold then
                    { i = r, j = c, v } :. \(cellsFrom r more)
                else
                    cellsFrom r more
                fi;
            _:
                if r + 1 < rows then cellsFrom (r + 1) allColumns else [] fi;
            esac;
        order = if isRowMajor? m then RowMajor () else ColMajor () fi;
        makeSparse order (size m)
            (if rows > 0 then cellsFrom 0 allColumns else [] fi);
    fi;

// Convert m to dense storage with the same storage order. A matrix that
// is already dense is returned unchanged.
toDense m =
    if isSparse? m then
        if isRowMajor? m then
            DenseRows (array (map (do r: getRow r m done) [0 .. height m - 1]))
        else
            DenseCols (array (map (do c: getColumn c m done) [0 .. width m - 1]))
        fi
    else
        m
    fi;

// A dense matrix of the given size with every element equal to n.
constMatrix n = generate do row col: n done;

// A dense matrix of the given size filled with values from Java's
// Math.random().
randomMatrix = generate do row col: Math#random() done;

// The identity matrix of the given size: 1 on the main diagonal and 0
// everywhere else. For a non-square size, 1 is placed where row == col.
identityMatrix = generate do row col: if row == col then 1 else 0 fi done;

// Transpose m by re-tagging its storage with the opposite order. No
// elements are moved or copied.
transposed m =
    case m of
    DenseCols storage: DenseRows storage;
    DenseRows storage: DenseCols storage;
    SparseCSC storage: SparseCSR storage;
    SparseCSR storage: SparseCSC storage;
    esac;

// The same matrix contents in the opposite storage order. Sparse input
// stays sparse and dense input stays dense; the elements are copied.
flipped m =
    if isSparse? m then
       (target = if isRowMajor? m then ColMajor () else RowMajor () fi;
        makeSparse target (size m) (enumerateSparse m))
    elif isRowMajor? m then
        // generate produces column-major storage directly.
        generate do row col: getAt row col m done (size m)
    else
        // Generate the transpose column-major, then re-tag it as row-major.
        transposed
           (generate do row col: getAt col row m done
                { rows = width m, columns = height m })
    fi;

// m in row-major storage, flipping it only if it is column-major.
toRowMajor m =
    if not (isRowMajor? m) then flipped m else m fi;

// m in column-major storage, flipping it only if it is row-major.
toColumnMajor m =
    if isRowMajor? m then flipped m else m fi;

// Compare two matrices that already share representation (both dense or
// both sparse) and storage order. Cells are compared with `comparator`
// and whole vectors with `vecComparator`. Sparse matrices are compared
// structurally (extent, values, indices and pointers), so the same
// values stored at different positions count as different.
// map2 stops at the end of the shorter list, so lengths are checked
// explicitly before comparing element-wise.
equal'' comparator vecComparator m1 m2 =
    // Prerequisite: m1 and m2 have same sparse-p and storage order
   (sameLength a b = length a == length b;
    compareLists l1 l2 =
        sameLength l1 l2 and all id (map2 comparator l1 l2);
    compareVecLists vv1 vv2 =
        sameLength vv1 vv2 and all id (map2 vecComparator vv1 vv2);
    compareSparse d1 d2 =
        d1.extent == d2.extent and
        vecComparator d1.values d2.values and
        compareLists d1.indices d2.indices and
        compareLists d1.pointers d2.pointers;
    case m1 of
    DenseRows d1:
        case m2 of DenseRows d2: compareVecLists d1 d2; _: false; esac;
    DenseCols d1:
        case m2 of DenseCols d2: compareVecLists d1 d2; _: false; esac;
    SparseCSR d1:
        case m2 of SparseCSR d2: compareSparse d1 d2; _: false; esac;
    SparseCSC d1:
        case m2 of SparseCSC d2: compareSparse d1 d2; _: false; esac;
    esac);

// Compare m1 and m2 whatever their storage. Matrices of different size
// are unequal. Otherwise m1 is flipped to match m2's storage order, and
// a dense side is made sparse (threshold 0) when the other is sparse,
// before the representation-specific comparison in equal''.
equal' comparator vecComparator m1 m2 =
    if size m1 != size m2 then
        false
    elif isRowMajor? m1 != isRowMajor? m2 then
        equal' comparator vecComparator (flipped m1) m2
    elif isSparse? m1 == isSparse? m2 then
        equal'' comparator vecComparator m1 m2
    elif isSparse? m1 then
        equal' comparator vecComparator m1 (toSparse 0 m2)
    else
        equal' comparator vecComparator (toSparse 0 m1) m2
    fi;

// Compare matrices using the given comparator for individual cells.
// Matrices with the same contents compare equal even when their storage
// order or dense/sparse representation differs, although such
// comparisons are slow because one side is converted first.
equalUnder comparator m1 m2 =
    equal' comparator (vec.equalUnder comparator) m1 m2;

// Exact comparison: cells compared with ==, vectors with vec.equal.
equal m1 m2 =
    equal' (==) vec.equal m1 m2;

// Wrap a list of vectors as a dense matrix: each vector is a row for
// RowMajor () and a column for ColumnMajor (). An empty list, or an
// empty first vector, gives the empty matrix.
newMatrix type data = //!!! NB does not copy data
    if empty? data or vec.empty? (head data) then
        zeroSizeMatrix ()
    else
        case type of
        RowMajor (): DenseRows (array data);
        ColumnMajor (): DenseCols (array data);
        esac
    fi;

// Wrap a single vector as a one-row (row-major) matrix.
newRowVector data = //!!! NB does not copy data
   (rowArray = array [data];
    DenseRows rowArray);

// Wrap a single vector as a one-column (column-major) matrix.
newColumnVector data = //!!! NB does not copy data
   (colArray = array [data];
    DenseCols colArray);

// m with every element multiplied by factor, as a dense column-major
// matrix.
scaled factor m = //!!! v inefficient
   (dims = size m;
    generate (do r c: factor * getAt r c m done) dims);

// Element-wise sum m1 + m2 as a dense column-major matrix. Fails if the
// sizes differ.
sum' m1 m2 =
    if (size m1) == (size m2) then
        generate do row col: getAt row col m1 + getAt row col m2 done (size m1)
    else
        failWith "Matrices are not the same size: \(size m1), \(size m2)"
    fi;

// Element-wise difference m1 - m2 as a dense column-major matrix. Fails
// if the sizes differ.
difference m1 m2 = //!!! doc: m1 - m2, not m2 - m1
    if (size m1) == (size m2) then
        generate do row col: getAt row col m1 - getAt row col m2 done (size m1)
    else
        failWith "Matrices are not the same size: \(size m1), \(size m2)"
    fi;

// Element-wise absolute value of m as a dense column-major matrix.
abs' m =
   (dims = size m;
    generate do r c: abs (getAt r c m) done dims);

// Matrix product m1 * m2 as a dense column-major matrix. Fails unless
// m1's column count equals m2's row count.
product m1 m2 =
   (s1 = size m1;
    s2 = size m2;
    if s1.columns != s2.rows then
        failWith "Matrix dimensions incompatible: \(s1), \(s2) (\(s1.columns) != \(s2.rows))"
    else
        // Extract each row of m1 and each column of m2 once, rather than
        // once per output cell.
        rows1 = array (map do i: getRow i m1 done [0 .. s1.rows - 1]);
        cols2 = array (map do j: getColumn j m2 done [0 .. s2.columns - 1]);
        generate do row col:
            bf.sum (bf.multiply rows1[row] cols2[col])
        done { rows = s1.rows, columns = s2.columns }
    fi);

// All rows of m, top to bottom, as a list of vectors.
asRows m =
    map ((flip getRow) m) [0 .. (height m) - 1];

// All columns of m, left to right, as a list of vectors.
asColumns m =
    map ((flip getColumn) m) [0 .. (width m) - 1];

// Concatenate along the storage-minor direction. For each major index k
// (counted on the first matrix), the k-th slices of all matrices are
// joined end to end into one vector, and the results are wrapped with
// tagger.
concatAgainstGrain tagger getter counter mm =
   (count = counter (size (head mm));
    joined = map do k: vec.concat (map (getter k) mm) done [0 .. count - 1];
    tagger (array joined));

// Concatenate along the storage-major direction. The slices of every
// matrix are listed one matrix after another and wrapped with tagger.
concatWithGrain tagger getter counter mm =
   (slicesOf m = map ((flip getter) m) [0 .. counter (size m) - 1];
    tagger (array (concat (map slicesOf mm))));

// Fail unless every matrix in mm matches `first` in the dimension that
// must agree for concatenation: the row count for Horizontal () and the
// column count otherwise.
checkDimensionsFor direction first mm =
   (counter = if direction == Horizontal () then (.rows) else (.columns) fi;
    expected = counter (size first);
    found = map do m: counter (size m) done mm;
    if not (all do c: c == expected done found) then
        failWith "Matrix dimensions incompatible for concat (found \(found) not all of which are \(expected))";
    fi);

// Concatenate the matrices in mm side by side (Horizontal ()) or one
// above another (Vertical ()). A single matrix is returned unchanged and
// an empty list gives the empty matrix. Otherwise the result is dense,
// in the storage order of the first matrix, and the dimensions are
// checked with checkDimensionsFor.
concat direction mm = //!!! doc: storage order is taken from first matrix in sequence
    //!!! would this be better as separate concatHorizontal/concatVertical functions?
    case mm of
    [single]:
        // This must come before first::_ below, which also matches a
        // one-element list and would make this branch unreachable.
        single;
    first::_:
        checkDimensionsFor direction first mm;
        row = isRowMajor? first;
        // horizontal, row-major: against grain with rows
        // horizontal, col-major: with grain with cols
        // vertical, row-major: with grain with rows
        // vertical, col-major: against grain with cols
        case direction of
        Horizontal ():
            if row then concatAgainstGrain DenseRows getRow (.rows) mm;
            else concatWithGrain DenseCols getColumn (.columns) mm;
            fi;
        Vertical ():
            if row then concatWithGrain DenseRows getRow (.rows) mm;
            else concatAgainstGrain DenseCols getColumn (.columns) mm;
            fi;
        esac;
    _: zeroSizeMatrix ();
    esac;

// Rows start .. start+count-1 of m as a dense matrix in m's storage
// order.
rowSlice start count m = //!!! doc: storage order same as input
    if not (isRowMajor? m) then
        DenseCols (array (map do c: vec.rangeOf start count c done (asColumns m)))
    else
        DenseRows (array (map do r: getRow r m done [start .. start + count - 1]))
    fi;

// Columns start .. start+count-1 of m as a dense matrix in m's storage
// order.
columnSlice start count m = //!!! doc: storage order same as input
    if isRowMajor? m then
        DenseRows (array (map do r: vec.rangeOf start count r done (asRows m)))
    else
        DenseCols (array (map do c: getColumn c m done [start .. start + count - 1]))
    fi;

// Return a matrix of size newsize holding the top-left part of m that
// fits, padded with zeros on the bottom/right where newsize is larger.
// m is returned as-is if it already has size newsize. Otherwise each
// recursive step slices or zero-pads one dimension until the size
// matches, so the result is dense in m's storage order.
resizedTo newsize m =
   (if newsize == (size m) then
        m
    elif (height m) == 0 or (width m) == 0 then
        // Nothing to keep: build a zero matrix of the target size directly.
        zeroMatrixWithTypeOf m newsize;
    else
        // Positive values mean the dimension must grow, negative shrink.
        growrows = newsize.rows - (height m);
        growcols = newsize.columns - (width m);
        rowm = isRowMajor? m;
        // Take one step towards newsize, then recurse. Shrinking comes
        // first, trying the storage-major dimension before the other
        // (presumably because that slice is cheaper); growing rows comes
        // before growing columns.
        resizedTo newsize
            if rowm and growrows < 0 then
                rowSlice 0 newsize.rows m
            elif (not rowm) and growcols < 0 then 
                columnSlice 0 newsize.columns m
            elif growrows < 0 then 
                rowSlice 0 newsize.rows m
            elif growcols < 0 then 
                columnSlice 0 newsize.columns m
            else
                // Append a zero block below (rows) or to the right (columns).
                if growrows > 0 then
                    concat (Vertical ())
                       [m, zeroMatrixWithTypeOf m ((size m) with { rows = growrows })]
                else
                    concat (Horizontal ())
                       [m, zeroMatrixWithTypeOf m ((size m) with { columns = growcols })]
                fi
            fi
    fi);

// Public interface of the module: the exported functions, with
// `sum` and `abs` bound to the internal sum' and abs' to avoid clashing
// with standard-library names, followed by their declared types.
{
    size,
    width,
    height,
    getAt,
    getColumn,
    getRow,
    isRowMajor?,
    isSparse?,
    generate,
    constMatrix,
    randomMatrix,
    zeroMatrix,
    identityMatrix,
    zeroSizeMatrix,
    equal,
    equalUnder,
    transposed,
    flipped,
    toRowMajor,
    toColumnMajor,
    toSparse,
    toDense,
    scaled,
    resizedTo,
    asRows,
    asColumns,
    sum = sum',
    difference,
    abs = abs',
    product,
    concat,
    rowSlice,
    columnSlice,
    newMatrix,
    newRowVector,
    newColumnVector,
}
as
{
//!!! check whether these are right to be .selector rather than just selector

    size is matrix -> { .rows is number, .columns is number },
    width is matrix -> number,
    height is matrix -> number,
    getAt is number -> number -> matrix -> number,
    getColumn is number -> matrix -> vector,
    getRow is number -> matrix -> vector,
    isRowMajor? is matrix -> boolean,
    isSparse? is matrix -> boolean,
    generate is (number -> number -> number) -> { .rows is number, .columns is number } -> matrix,
    constMatrix is number -> { .rows is number, .columns is number } -> matrix,
    randomMatrix is { .rows is number, .columns is number } -> matrix,
    zeroMatrix is { .rows is number, .columns is number } -> matrix, 
    identityMatrix is { .rows is number, .columns is number } -> matrix, 
    zeroSizeMatrix is () -> matrix,
    equal is matrix -> matrix -> boolean,
    equalUnder is (number -> number -> boolean) -> matrix -> matrix -> boolean,
    transposed is matrix -> matrix,
    flipped is matrix -> matrix, 
    toRowMajor is matrix -> matrix, 
    toColumnMajor is matrix -> matrix,
    toSparse is number -> matrix -> matrix,
    toDense is matrix -> matrix,
    scaled is number -> matrix -> matrix,
    resizedTo is { .rows is number, .columns is number } -> matrix -> matrix,
    asRows is matrix -> list<vector>, 
    asColumns is matrix -> list<vector>,
    sum is matrix -> matrix -> matrix,
    difference is matrix -> matrix -> matrix,
    abs is matrix -> matrix,
    product is matrix -> matrix -> matrix,
    concat is (Horizontal () | Vertical ()) -> list<matrix> -> matrix,
    rowSlice is number -> number -> matrix -> matrix, 
    columnSlice is number -> number -> matrix -> matrix,
    newMatrix is (ColumnMajor () | RowMajor ()) -> list<vector> -> matrix, 
    newRowVector is vector -> matrix, 
    newColumnVector is vector -> matrix,
}