view yetilab/matrix/matrix.yeti @ 222:77c6a81c577f matrix_opaque_immutable

Move block directory -> vector
author Chris Cannam
date Sat, 11 May 2013 15:58:36 +0100
parents 709fba377099
children c00d8f7e2708
line wrap: on
line source

module yetilab.matrix.matrix;

// A matrix is an array of vectors.

// A matrix can be stored in either column-major (the default) or
// row-major format. Storage order is an efficiency concern only:
// every API function operating on matrix objects will return the same
// result regardless of storage order. (The transposed function just
// swaps the row/column storage tag without moving any elements.)

//!!! check that we are not unnecessarily copying in the transform functions

vec = load yetilab.vector.vector;
bf = load yetilab.vector.blockfuncs;

load yetilab.vector.vectortype;
load yetilab.matrix.matrixtype;

// Return the dimensions of m as a { rows, columns } record.
// The major extent is the number of stored vectors; the minor extent
// is the length of the first stored vector, or 0 if nothing is stored.
size m =
   (extents d =
       (major = length d;
        minor = if major == 0 then 0 else vec.length d[0] fi;
        { major, minor });
    case m of
    RowM d:
        e = extents d;
        { rows = e.major, columns = e.minor };
    ColM d:
        e = extents d;
        { rows = e.minor, columns = e.major };
    esac);

// Number of columns in m.
width m = (.columns) (size m);

// Number of rows in m.
height m = (.rows) (size m);

// Return the element at (row, col) of m, reading it from whichever
// stored vector (a row or a column, depending on storage order) holds it.
getAt row col m =
    case m of
    RowM storage: vec.at col storage[row];
    ColM storage: vec.at row storage[col];
    esac;

// Return column j of m as a vector. With column-major storage this is
// the stored vector itself. With row-major storage a new vector is
// assembled from element j of every row, in row order.
getColumn j m =
    case m of
    ColM cols: cols[j];
    RowM rows: vec.fromList (map (vec.at j) rows);
    esac;

// Return row i of m as a vector. With row-major storage this is the
// stored vector itself. With column-major storage a new vector is
// assembled from element i of every column, in column order.
getRow i m =
    case m of
    RowM rows: rows[i];
    ColM cols: vec.fromList (map (vec.at i) cols);
    esac;

/*
setAt row col n m = //!!! dangerous, could modify copies -- should it be allowed?
    case m of
    RowM rows: r = rows[row]; (vec.data r)[col] := n;
    ColM cols: c = cols[col]; (vec.data c)[row] := n;
    esac;
*/

// True if m uses row-major storage, false if it is column-major.
isRowMajor? m =
    case m of
    ColM _: false;
    RowM _: true;
    esac;

// Allocate column-major storage for a rows x columns matrix: an array
// of `columns` zero-filled vectors, each of length `rows`. If rows is
// less than 1 the storage is empty, whatever the column count.
newColMajorStorage { rows, columns } =
   (zeroColumn = \(vec.zeros rows);
    columnList = if rows < 1 then [] else map zeroColumn [1..columns] fi;
    array columnList);

// A rows x columns matrix of zeros, always in column-major storage.
zeroMatrix { rows, columns } =
   (storage = newColMajorStorage { rows, columns };
    ColM storage);

// A rows x columns zero matrix whose storage order matches m's.
// Row-major storage has the same layout as column-major storage of the
// transposed shape, so one allocator serves both orders.
zeroMatrixWithTypeOf m { rows, columns } =
    case m of
    RowM _: RowM (newColMajorStorage { rows = columns, columns = rows });
    ColM _: ColM (newColMajorStorage { rows, columns });
    esac;

// A matrix with no rows and no columns (column-major, empty storage).
zeroSizeMatrix () = ColM (array []);

// Build a rows x columns matrix whose element at (row, col) is
// f row col. The result is always column-major. Each column is filled
// top to bottom before moving on to the next column. If either
// dimension is less than 1, a zero-size matrix is returned and f is
// never called.
generate f { rows, columns } =
    if rows < 1 or columns < 1 then
        zeroSizeMatrix ()
    else
        fillColumn col =
           (buffer = new double[rows];
            for [0..rows-1] do row:
                buffer[row] := f row col;
            done;
            vec.vector buffer);
        ColM (array (map fillColumn [0..columns-1]))
    fi;

// A matrix of the given size with every element equal to n.
constMatrix n = generate do row col: n done;

// A matrix of the given size filled with values from
// java.lang.Math.random(), i.e. uniformly distributed in [0, 1).
randomMatrix = generate do row col: Math#random() done;

// The identity matrix of the given size: 1 where row == col, 0
// elsewhere. Non-square sizes are accepted; the diagonal then stops
// at the shorter dimension. (This was previously constMatrix 1, which
// produced an all-ones matrix rather than an identity.)
identityMatrix = generate do row col: if row == col then 1 else 0 fi done;

// Transpose m without touching its elements: the same stored vectors
// are simply re-tagged, so stored rows become columns and vice versa.
transposed m =
    case m of
    ColM d: RowM d;
    RowM d: ColM d;
    esac;

// Return a matrix with the same contents as m but the opposite storage
// order. generate always produces column-major output, so a row-major
// result is made by generating the column-major transpose and
// re-tagging it with transposed.
flipped m =
   (sz = size m;
    case m of
    RowM _:
        generate do r c: getAt r c m done sz;
    ColM _:
        transposed
           (generate do r c: getAt c r m done
                { rows = sz.columns, columns = sz.rows });
    esac);

// Return m in row-major storage; m itself if it is already row-major.
toRowMajor m =
    case m of
    RowM _: m;
    ColM _: flipped m;
    esac;

// Return m in column-major storage; m itself if it is already column-major.
toColumnMajor m =
    case m of
    ColM _: m;
    RowM _: flipped m;
    esac;

// Test whether m1 and m2 have the same size and identical contents.
// Storage order does not matter. When the orders differ, m1 is first
// converted with flipped, which copies every element and so is slow.
equal m1 m2 =
    if size m1 != size m2 then
        false
    else
        sameVectors a b = all id (map2 vec.equal a b);
        case m1 of
        RowM a:
            case m2 of
            RowM b: sameVectors a b;
            ColM _: equal (flipped m1) m2;
            esac;
        ColM a:
            case m2 of
            ColM b: sameVectors a b;
            RowM _: equal (flipped m1) m2;
            esac;
        esac
    fi;

/*!!! not needed now it's immutable?
copyOf m =
   (copyOfData d = (array (map vec.copyOf d));
    case m of
    RowM d: RowM (copyOfData d);
    ColM d: ColM (copyOfData d);
    esac);
*/

// Build a matrix from a list of vectors: each vector is a row for
// RowMajor (), or a column for ColumnMajor (). The vectors are used as
// storage directly, not copied. An empty list, or one whose first
// vector is empty, gives a zero-size matrix. The vectors' lengths are
// not checked against one another.
newMatrix type data = //!!! NB does not copy data
    if empty? data or vec.empty? (head data) then
        zeroSizeMatrix ()
    else
        storage = array data;
        case type of
        RowMajor (): RowM storage;
        ColumnMajor (): ColM storage;
        esac
    fi;

// Wrap a single vector as a one-column matrix. The vector is used as
// storage directly, not copied.
newColumnVector data = //!!! NB does not copy data
    ColM (array [data]);

// Wrap a single vector as a one-row matrix: the transpose of the
// corresponding one-column matrix, sharing the same storage.
newRowVector data = //!!! NB does not copy data
    transposed (newColumnVector data);

// Multiply every element of m by factor. The result is always
// column-major, whatever the storage order of m.
scaled factor m = //!!! v inefficient
   (element r c = factor * getAt r c m;
    generate element (size m));

// Element-wise sum of two matrices of the same size; fails if the
// sizes differ. The result is always column-major.
sum' m1 m2 =
   (s1 = size m1;
    s2 = size m2;
    if s1 == s2 then
        generate do r c: getAt r c m1 + getAt r c m2 done s1;
    else
        failWith "Matrices are not the same size: \(s1), \(s2)";
    fi);

// Matrix product m1 * m2. m1's column count must equal m2's row count,
// otherwise this fails with a message giving both sizes. The result
// has m1's row count and m2's column count, and is always column-major.
product m1 m2 =
   (s1 = size m1;
    s2 = size m2;
    if s1.columns != s2.rows then
        failWith "Matrix dimensions incompatible: \(s1), \(s2) (\(s1.columns) != \(s2.rows))";
    else
        // Extract each row of m1 and each column of m2 once, up front.
        // getRow/getColumn build a new vector element by element when
        // they run against the storage order. Calling them inside the
        // generator would repeat that work for every output cell.
        rows1 = array (map do i: getRow i m1 done [0 .. s1.rows - 1]);
        cols2 = array (map do j: getColumn j m2 done [0 .. s2.columns - 1]);
        generate do row col:
            bf.sum (bf.multiply (rows1[row]) (cols2[col]))
        done { rows = s1.rows, columns = s2.columns }
    fi);

// The rows of m, first to last, as a list of vectors.
asRows m =
    map (flip getRow m) [0 .. height m - 1];

// The columns of m, first to last, as a list of vectors.
asColumns m =
    map (flip getColumn m) [0 .. width m - 1];

// Join the non-empty list of matrices mm across the storage grain.
// For each index i below counter (size (head mm)), the i-th vectors of
// all the matrices (fetched with getter) are concatenated end to end
// into one vector. The resulting vectors are wrapped with tagger.
concatAgainstGrain tagger getter counter mm =
   (joined i = vec.concat (map (getter i) mm);
    count = counter (size (head mm));
    tagger (array (map joined [0..count-1])));

// Join the matrices in mm along the storage grain: the vectors of each
// matrix (fetched with getter, counter (size m) of them) are listed in
// order, one matrix after another, and wrapped with tagger.
// NB: concat here is the standard-library list concat.
concatWithGrain tagger getter counter mm =
   (vectorsOf m = map (flip getter m) [0 .. counter (size m) - 1];
    tagger (array (concat (map vectorsOf mm))));

// Fail unless every matrix in mm agrees with first on the dimension
// that must match for concatenation in the given direction: the row
// count for Horizontal (), the column count otherwise.
checkDimensionsFor direction first mm =
   (counter = if direction == Horizontal () then (.rows) else (.columns) fi;
    expected = counter (size first);
    found = map do m: counter (size m) done mm;
    if not (all do c: c == expected done found) then
        failWith "Matrix dimensions incompatible for concat (found \(found) not all of which are \(expected))";
    fi);

// Concatenate a list of matrices side by side (Horizontal ()) or one
// above another (Vertical ()). The result uses the storage order of the
// first matrix in the list. Fails if the matrices disagree on the row
// count (horizontal) or column count (vertical). A single matrix is
// returned unchanged; an empty list gives a zero-size matrix.
//
// The single-matrix case is matched first. Previously first::rest came
// first, which made the [single] branch unreachable and rebuilt the
// storage of a lone matrix for nothing.
concat direction mm = //!!! doc: storage order is taken from first matrix in sequence
    //!!! would this be better as separate concatHorizontal/concatVertical functions?
    case mm of
    [single]: single;
    first::_:
        checkDimensionsFor direction first mm;
        row = isRowMajor? first;
        // horizontal, row-major: against grain with rows
        // horizontal, col-major: with grain with cols
        // vertical, row-major: with grain with rows
        // vertical, col-major: against grain with cols
        case direction of
        Horizontal ():
            if row then concatAgainstGrain RowM getRow (.rows) mm;
            else concatWithGrain ColM getColumn (.columns) mm;
            fi;
        Vertical ():
            if row then concatWithGrain RowM getRow (.rows) mm;
            else concatAgainstGrain ColM getColumn (.columns) mm;
            fi;
        esac;
    _: zeroSizeMatrix ();
    esac;

// Rows start .. start+count-1 of m, kept in m's storage order.
// Row-major: the selected row vectors are reused as they are.
// Column-major: each column is cut down to the requested range.
rowSlice start count m = //!!! doc: storage order same as input
    case m of
    RowM rows:
        RowM (array (map do i: rows[i] done [start .. start + count - 1]));
    ColM cols:
        ColM (array (map (vec.rangeOf start count) cols));
    esac;

// Columns start .. start+count-1 of m, kept in m's storage order.
// Column-major: the selected column vectors are reused as they are.
// Row-major: each row is cut down to the requested range.
columnSlice start count m = //!!! doc: storage order same as input
    case m of
    ColM cols:
        ColM (array (map do j: cols[j] done [start .. start + count - 1]));
    RowM rows:
        RowM (array (map (vec.rangeOf start count) rows));
    esac;

// Return a matrix of size newsize containing those elements of m that
// fit. Any added rows (at the bottom) or columns (at the right) are
// filled with zeros. The storage order of m is kept: rowSlice,
// columnSlice, concat (which takes the order of its first matrix, m)
// and zeroMatrixWithTypeOf all preserve it.
//
// Each recursive call brings one dimension closer to its target.
// Shrinking happens before growing. When shrinking, the dimension that
// drops whole stored vectors (rows for row-major, columns for
// column-major) is sliced first. When growing, rows are added before
// columns. Recursion stops once size m equals newsize, or once m has a
// zero dimension.
//
// NOTE(review): if newsize itself has a zero dimension, the result can
// come back 0x0 rather than exactly newsize, because newColMajorStorage
// returns empty storage when asked for fewer than one row.
resizedTo newsize m =
   (if newsize == (size m) then
        m
    elif (height m) == 0 or (width m) == 0 then
        // Nothing to keep from m: just build zeros in m's storage order.
        zeroMatrixWithTypeOf m newsize;
    else
        // Positive = dimension must grow, negative = must shrink.
        growrows = newsize.rows - (height m);
        growcols = newsize.columns - (width m);
        rowm = isRowMajor? m;
        resizedTo newsize
            // Shrink first, starting along the storage grain.
            if rowm and growrows < 0 then
                rowSlice 0 newsize.rows m
            elif (not rowm) and growcols < 0 then 
                columnSlice 0 newsize.columns m
            elif growrows < 0 then 
                rowSlice 0 newsize.rows m
            elif growcols < 0 then 
                columnSlice 0 newsize.columns m
            else
                // Nothing left to shrink: pad with zeros, rows first.
                if growrows > 0 then
                    concat (Vertical ())
                       [m, zeroMatrixWithTypeOf m ((size m) with { rows = growrows })]
                else
                    concat (Horizontal ())
                       [m, zeroMatrixWithTypeOf m ((size m) with { columns = growcols })]
                fi
            fi
    fi);

// Public interface of this module: the record of exported functions,
// with the type signature given after `as`. Helpers that are not
// listed here (newColMajorStorage, zeroMatrixWithTypeOf,
// concatAgainstGrain, concatWithGrain, checkDimensionsFor) stay private
// to the module.
{
    size,
    width,
    height,
    getAt,
    getColumn,
    getRow,
//    setAt,
    isRowMajor?,
    generate,
    constMatrix,
    randomMatrix,
    zeroMatrix,
    identityMatrix,
    zeroSizeMatrix,
    equal,
//    copyOf,
    transposed,
    flipped,
    toRowMajor,
    toColumnMajor,
    scaled,
    resizedTo,
    asRows,
    asColumns,
    // sum' is exported under the name sum. Presumably the local name is
    // primed to avoid clashing with the standard library's sum.
    sum = sum',
    product,
    concat,
    rowSlice,
    columnSlice,
    newMatrix,
    newRowVector,
    newColumnVector,
}
as
{
//!!! check whether these are right to be .selector rather than just selector

    size is matrix -> { .rows is number, .columns is number },
    width is matrix -> number,
    height is matrix -> number,
    getAt is number -> number -> matrix -> number,
    getColumn is number -> matrix -> vector,
    getRow is number -> matrix -> vector,
//    setAt is number -> number -> number -> matrix -> (), //!!! lose?
    isRowMajor? is matrix -> boolean,
    generate is (number -> number -> number) -> { .rows is number, .columns is number } -> matrix,
    constMatrix is number -> { .rows is number, .columns is number } -> matrix,
    randomMatrix is { .rows is number, .columns is number } -> matrix,
    zeroMatrix is { .rows is number, .columns is number } -> matrix, 
    identityMatrix is { .rows is number, .columns is number } -> matrix, 
    zeroSizeMatrix is () -> matrix,
    equal is matrix -> matrix -> boolean,
//    copyOf is matrix -> matrix,
    transposed is matrix -> matrix,
    flipped is matrix -> matrix, 
    toRowMajor is matrix -> matrix, 
    toColumnMajor is matrix -> matrix,
    scaled is number -> matrix -> matrix,
    resizedTo is { .rows is number, .columns is number } -> matrix -> matrix,
    asRows is matrix -> list<vector>, 
    asColumns is matrix -> list<vector>,
    sum is matrix -> matrix -> matrix,
    product is matrix -> matrix -> matrix,
    concat is (Horizontal () | Vertical ()) -> list<matrix> -> matrix,
    rowSlice is number -> number -> matrix -> matrix, 
    columnSlice is number -> number -> matrix -> matrix,
    newMatrix is (ColumnMajor () | RowMajor ()) -> list<vector> -> matrix, 
    newRowVector is vector -> matrix, 
    newColumnVector is vector -> matrix,
}