diff yetilab/matrix/matrix.yeti @ 223:51af10e6cd0d

Merge from matrix_opaque_immutable branch
author Chris Cannam
date Sat, 11 May 2013 16:00:58 +0100
parents 77c6a81c577f
children c00d8f7e2708
line wrap: on
line diff
--- a/yetilab/matrix/matrix.yeti	Tue May 07 21:43:10 2013 +0100
+++ b/yetilab/matrix/matrix.yeti	Sat May 11 16:00:58 2013 +0100
@@ -1,7 +1,7 @@
 
 module yetilab.matrix.matrix;
 
-// A matrix is an array of fvectors (i.e. primitive double[]s).
+// A matrix is an array of vectors.
 
 // A matrix can be stored in either column-major (the default) or
 // row-major format. Storage order is an efficiency concern only:
@@ -9,56 +9,64 @@
 // result regardless of storage order.  (The transpose function just
 // switches the row/column order without moving the elements.)
 
-vec = load yetilab.block.fvector;
-block = load yetilab.block.block;
-bf = load yetilab.block.blockfuncs;
+//!!! check that we are not unnecessarily copying in the transform functions
 
-load yetilab.block.blocktype;
+vec = load yetilab.vector.vector;
+bf = load yetilab.vector.blockfuncs;
+
+load yetilab.vector.vectortype;
 load yetilab.matrix.matrixtype;
 
-make d = {
-    get data () = d,
-    get size () =
-        case d of
-        RowM r:
-            major = length r;
-            { 
-                rows = major, 
-                columns = if major > 0 then vec.length r[0] else 0 fi,
-            };
-        ColM c:
-            major = length c;
-            { 
-                rows = if major > 0 then vec.length c[0] else 0 fi,
-                columns = major, 
-            };
-        esac,
-    getColumn j =
-        case d of
-        RowM rows: block.fromList (map do i: getAt i j done [0..length rows-1]);
-        ColM cols: block.block cols[j];
-        esac,
-    getRow i =
-        case d of
-        RowM rows: block.block rows[i];
-        ColM cols: block.fromList (map do j: getAt i j done [0..length cols-1]);
-        esac,
-    getAt row col =
-        case d of
-        RowM rows: r = rows[row]; (r is ~double[])[col];
-        ColM cols: c = cols[col]; (c is ~double[])[row];
-        esac,
-    setAt row col n = //!!! dangerous, could modify copies -- should it be allowed?
-        case d of
-        RowM rows: r = rows[row]; (r is ~double[])[col] := n;
-        ColM cols: c = cols[col]; (c is ~double[])[row] := n;
-        esac,
-    get isRowMajor? () =
-        case d of
-        RowM _: true;
-        ColM _: false;
-        esac,
-    };
+size m =
+    case m of
+    RowM r:
+        major = length r;
+        { 
+            rows = major, 
+            columns = if major > 0 then vec.length r[0] else 0 fi,
+        };
+    ColM c:
+        major = length c;
+        { 
+            rows = if major > 0 then vec.length c[0] else 0 fi,
+            columns = major, 
+        };
+    esac;
+
+width m = (size m).columns;
+height m = (size m).rows;
+
+getAt row col m =
+    case m of
+    RowM rows: r = rows[row]; vec.at col r;
+    ColM cols: c = cols[col]; vec.at row c;
+    esac;
+
+getColumn j m =
+    case m of
+    RowM rows: vec.fromList (map do i: getAt i j m done [0..length rows-1]);
+    ColM cols: cols[j];
+    esac;
+
+getRow i m =
+    case m of
+    RowM rows: rows[i];
+    ColM cols: vec.fromList (map do j: getAt i j m done [0..length cols-1]);
+    esac;
+
+/*
+setAt row col n m = //!!! dangerous, could modify copies -- should it be allowed?
+    case m of
+    RowM rows: r = rows[row]; (vec.data r)[col] := n;
+    ColM cols: c = cols[col]; (vec.data c)[row] := n;
+    esac;
+*/
+
+isRowMajor? m =
+    case m of
+    RowM _: true;
+    ColM _: false;
+    esac;
 
 newColMajorStorage { rows, columns } = 
     if rows < 1 then array []
@@ -66,133 +74,134 @@
     fi;
 
 zeroMatrix { rows, columns } = 
-    make (ColM (newColMajorStorage { rows, columns }));
+    ColM (newColMajorStorage { rows, columns });
 
 zeroMatrixWithTypeOf m { rows, columns } = 
-    if m.isRowMajor? then
-        make (RowM (newColMajorStorage { rows = columns, columns = rows }));
+    if isRowMajor? m then
+        RowM (newColMajorStorage { rows = columns, columns = rows });
     else
-        make (ColM (newColMajorStorage { rows, columns }));
+        ColM (newColMajorStorage { rows, columns });
     fi;
 
+zeroSizeMatrix () = zeroMatrix { rows = 0, columns = 0 };
+
 generate f { rows, columns } =
-   (m = newColMajorStorage { rows, columns };
-    for [0..columns-1] do col:
-        for [0..rows-1] do row:
-            m[col][row] := f row col;
+    if rows < 1 or columns < 1 then zeroSizeMatrix ()
+    else
+        m = array (map \(new double[rows]) [1..columns]);
+        for [0..columns-1] do col:
+            for [0..rows-1] do row:
+                m[col][row] := f row col;
+            done;
         done;
-    done;
-    make (ColM m));
+        ColM (array (map vec.vector m))
+    fi;
 
 constMatrix n = generate do row col: n done;
 randomMatrix = generate do row col: Math#random() done;
 identityMatrix = constMatrix 1;
-zeroSizeMatrix () = zeroMatrix { rows = 0, columns = 0 };
-
-width m = m.size.columns;
-height m = m.size.rows;
 
 transposed m =
-    make
-       (case m.data of
-        RowM d: ColM d;
-        ColM d: RowM d;
-        esac);
+    case m of
+    RowM d: ColM d;
+    ColM d: RowM d;
+    esac;
 
 flipped m =
-    if m.isRowMajor? then
-        generate do row col: m.getAt row col done m.size;
+    if isRowMajor? m then
+        generate do row col: getAt row col m done (size m);
     else
         transposed
-           (generate do row col: m.getAt col row done
-            { rows = m.size.columns, columns = m.size.rows });
+           (generate do row col: getAt col row m done
+            { rows = (width m), columns = (height m) });
     fi;
 
 toRowMajor m =
-    if m.isRowMajor? then m else flipped m fi;
+    if isRowMajor? m then m else flipped m fi;
 
 toColumnMajor m =
-    if not m.isRowMajor? then m else flipped m fi;
+    if not isRowMajor? m then m else flipped m fi;
 
 // Matrices with different storage order but the same contents are
 // equal (but comparing them is slow)
 equal m1 m2 =
-    if m1.size != m2.size then false
-    elif m1.isRowMajor? != m2.isRowMajor? then equal (flipped m1) m2;
+    if size m1 != size m2 then false
+    elif isRowMajor? m1 != isRowMajor? m2 then equal (flipped m1) m2;
     else
         compare d1 d2 = all id (map2 vec.equal d1 d2);
-        case m1.data of
-        RowM d1: case m2.data of RowM d2: compare d1 d2; _: false; esac;
-        ColM d1: case m2.data of ColM d2: compare d1 d2; _: false; esac;
+        case m1 of
+        RowM d1: case m2 of RowM d2: compare d1 d2; _: false; esac;
+        ColM d1: case m2 of ColM d2: compare d1 d2; _: false; esac;
         esac
     fi;
 
+/*!!! not needed now it's immutable?
 copyOf m =
    (copyOfData d = (array (map vec.copyOf d));
-    make
-       (case m.data of
-        RowM d: RowM (copyOfData d);
-        ColM d: ColM (copyOfData d);
-        esac));
+    case m of
+    RowM d: RowM (copyOfData d);
+    ColM d: ColM (copyOfData d);
+    esac);
+*/
 
 newMatrix type data = //!!! NB does not copy data
    (tagger = case type of RowMajor (): RowM; ColumnMajor (): ColM esac;
-    if empty? data or block.empty? (head data)
+    if empty? data or vec.empty? (head data)
     then zeroSizeMatrix ()
-    else make (tagger (array (map block.data data)))
+    else tagger (array data)
     fi);
 
 newRowVector data = //!!! NB does not copy data
-    make (RowM (array [block.data data]));
+    RowM (array [data]);
 
 newColumnVector data = //!!! NB does not copy data
-    make (ColM (array [block.data data]));
+    ColM (array [data]);
 
 scaled factor m = //!!! v inefficient
-    generate do row col: factor * m.getAt row col done m.size;
+    generate do row col: factor * (getAt row col m) done (size m);
 
 sum' m1 m2 =
-    if m1.size != m2.size
-    then failWith "Matrices are not the same size: \(m1.size), \(m2.size)";
+    if (size m1) != (size m2)
+    then failWith "Matrices are not the same size: \(size m1), \(size m2)";
     else
-        generate do row col: m1.getAt row col + m2.getAt row col done m1.size;
+        generate do row col: getAt row col m1 + getAt row col m2 done (size m1);
     fi;
 
 product m1 m2 =
-    if m1.size.columns != m2.size.rows
-    then failWith "Matrix dimensions incompatible: \(m1.size), \(m2.size) (\(m1.size.columns != m2.size.rows)";
+    if (size m1).columns != (size m2).rows
+    then failWith "Matrix dimensions incompatible: \(size m1), \(size m2) (\((size m1).columns != (size m2).rows)";
     else
         generate do row col:
-            bf.sum (bf.multiply (m1.getRow row) (m2.getColumn col))
-        done { rows = m1.size.rows, columns = m2.size.columns }
+            bf.sum (bf.multiply (getRow row m1) (getColumn col m2))
+        done { rows = (size m1).rows, columns = (size m2).columns }
     fi;
 
 asRows m =
-    map m.getRow [0 .. m.size.rows - 1];
+    map do i: getRow i m done [0 .. (height m) - 1];
 
 asColumns m =
-    map m.getColumn [0 .. m.size.columns - 1];
+    map do i: getColumn i m done [0 .. (width m) - 1];
 
 concatAgainstGrain tagger getter counter mm =
-   (n = counter (head mm).size;
-    make (tagger (array
+   (n = counter (size (head mm));
+    tagger (array
        (map do i:
-           block.data (block.concat (map do m: getter m i done mm))
-           done [0..n-1]))));
+           vec.concat (map (getter i) mm)
+           done [0..n-1])));
 
 concatWithGrain tagger getter counter mm =
-    make (tagger (array
+    tagger (array
        (concat
            (map do m:
-               n = counter m.size;
-               map do i: block.data (getter m i) done [0..n-1]
-               done mm))));
+               n = counter (size m);
+               map do i: getter i m done [0..n-1]
+               done mm)));
 
 checkDimensionsFor direction first mm =
    (counter = if direction == Horizontal () then (.rows) else (.columns) fi;
-    n = counter first.size;
-    if not (all id (map do m: counter m.size == n done mm)) then
-        failWith "Matrix dimensions incompatible for concat (found \(map do m: counter m.size done mm) not all of which are \(n))";
+    n = counter (size first);
+    if not (all id (map do m: counter (size m) == n done mm)) then
+        failWith "Matrix dimensions incompatible for concat (found \(map do m: counter (size m) done mm) not all of which are \(n))";
     fi);
 
 concat direction mm = //!!! doc: storage order is taken from first matrix in sequence
@@ -200,19 +209,19 @@
     case mm of
     first::rest: 
         checkDimensionsFor direction first mm;
-        row = first.isRowMajor?;
+        row = isRowMajor? first;
         // horizontal, row-major: against grain with rows
         // horizontal, col-major: with grain with cols
         // vertical, row-major: with grain with rows
         // vertical, col-major: against grain with cols
         case direction of
         Horizontal ():
-            if row then concatAgainstGrain RowM (.getRow) (.rows) mm;
-            else concatWithGrain ColM (.getColumn) (.columns) mm;
+            if row then concatAgainstGrain RowM getRow (.rows) mm;
+            else concatWithGrain ColM getColumn (.columns) mm;
             fi;
         Vertical ():
-            if row then concatWithGrain RowM (.getRow) (.rows) mm;
-            else concatAgainstGrain ColM (.getColumn) (.columns) mm;
+            if row then concatWithGrain RowM getRow (.rows) mm;
+            else concatAgainstGrain ColM getColumn (.columns) mm;
             fi;
         esac;
     [single]: single;
@@ -220,28 +229,28 @@
     esac;
 
 rowSlice start count m = //!!! doc: storage order same as input
-    if m.isRowMajor? then
-        make (RowM (array (map (block.data . m.getRow) [start .. start + count - 1])))
+    if isRowMajor? m then
+        RowM (array (map ((flip getRow) m) [start .. start + count - 1]))
     else 
-        make (ColM (array (map (block.data . (block.rangeOf start count)) (asColumns m))))
+        ColM (array (map (vec.rangeOf start count) (asColumns m)))
     fi;
 
 columnSlice start count m = //!!! doc: storage order same as input
-    if not m.isRowMajor? then
-        make (ColM (array (map (block.data . m.getColumn) [start .. start + count - 1])))
+    if not isRowMajor? m then
+        ColM (array (map ((flip getColumn) m) [start .. start + count - 1]))
     else 
-        make (RowM (array (map (block.data . (block.rangeOf start count)) (asRows m))))
+        RowM (array (map (vec.rangeOf start count) (asRows m)))
     fi;
 
 resizedTo newsize m =
-   (if newsize == m.size then
+   (if newsize == (size m) then
         m
-    elif m.size.rows == 0 or m.size.columns == 0 then
+    elif (height m) == 0 or (width m) == 0 then
         zeroMatrixWithTypeOf m newsize;
     else
-        growrows = newsize.rows - m.size.rows;
-        growcols = newsize.columns - m.size.columns;
-        rowm = m.isRowMajor?;
+        growrows = newsize.rows - (height m);
+        growcols = newsize.columns - (width m);
+        rowm = isRowMajor? m;
         resizedTo newsize
             if rowm and growrows < 0 then
                 rowSlice 0 newsize.rows m
@@ -254,40 +263,83 @@
             else
                 if growrows > 0 then
                     concat (Vertical ())
-                       [m, zeroMatrixWithTypeOf m (m.size with { rows = growrows })]
+                       [m, zeroMatrixWithTypeOf m ((size m) with { rows = growrows })]
                 else
                     concat (Horizontal ())
-                       [m, zeroMatrixWithTypeOf m (m.size with { columns = growcols })]
+                       [m, zeroMatrixWithTypeOf m ((size m) with { columns = growcols })]
                 fi
             fi
     fi);
 
 {
+    size,
+    width,
+    height,
+    getAt,
+    getColumn,
+    getRow,
+//    setAt,
+    isRowMajor?,
+    generate,
+    constMatrix,
+    randomMatrix,
+    zeroMatrix,
+    identityMatrix,
+    zeroSizeMatrix,
+    equal,
+//    copyOf,
+    transposed,
+    flipped,
+    toRowMajor,
+    toColumnMajor,
+    scaled,
+    resizedTo,
+    asRows,
+    asColumns,
+    sum = sum',
+    product,
+    concat,
+    rowSlice,
+    columnSlice,
+    newMatrix,
+    newRowVector,
+    newColumnVector,
+}
+as
+{
+//!!! check whether these are right to be .selector rather than just selector
+
+    size is matrix -> { .rows is number, .columns is number },
+    width is matrix -> number,
+    height is matrix -> number,
+    getAt is number -> number -> matrix -> number,
+    getColumn is number -> matrix -> vector,
+    getRow is number -> matrix -> vector,
+//    setAt is number -> number -> number -> matrix -> (), //!!! lose?
+    isRowMajor? is matrix -> boolean,
     generate is (number -> number -> number) -> { .rows is number, .columns is number } -> matrix,
     constMatrix is number -> { .rows is number, .columns is number } -> matrix,
     randomMatrix is { .rows is number, .columns is number } -> matrix,
     zeroMatrix is { .rows is number, .columns is number } -> matrix, 
     identityMatrix is { .rows is number, .columns is number } -> matrix, 
     zeroSizeMatrix is () -> matrix,
-    width is matrix -> number,
-    height is matrix -> number,
     equal is matrix -> matrix -> boolean,
-    copyOf is matrix -> matrix,
+//    copyOf is matrix -> matrix,
     transposed is matrix -> matrix,
     flipped is matrix -> matrix, 
     toRowMajor is matrix -> matrix, 
     toColumnMajor is matrix -> matrix,
     scaled is number -> matrix -> matrix,
     resizedTo is { .rows is number, .columns is number } -> matrix -> matrix,
-    asRows is matrix -> list<block>, 
-    asColumns is matrix -> list<block>,
-    sum is matrix -> matrix -> matrix = sum',
+    asRows is matrix -> list<vector>, 
+    asColumns is matrix -> list<vector>,
+    sum is matrix -> matrix -> matrix,
     product is matrix -> matrix -> matrix,
     concat is (Horizontal () | Vertical ()) -> list<matrix> -> matrix,
     rowSlice is number -> number -> matrix -> matrix, 
     columnSlice is number -> number -> matrix -> matrix,
-    newMatrix is (ColumnMajor () | RowMajor ()) -> list<block> -> matrix, 
-    newRowVector is block -> matrix, 
-    newColumnVector is block -> matrix,
+    newMatrix is (ColumnMajor () | RowMajor ()) -> list<vector> -> matrix, 
+    newRowVector is vector -> matrix, 
+    newColumnVector is vector -> matrix,
 }