changeset 98:bd135a950af7

Scaled, sum, product
author Chris Cannam
date Thu, 21 Mar 2013 10:18:18 +0000
parents d5fc902dcc3f
children 9832210dc42c
files yetilab/matrix/matrix.yeti yetilab/matrix/test/test_matrix.yeti
diffstat 2 files changed, 98 insertions(+), 35 deletions(-) [+]
line wrap: on
line diff
--- a/yetilab/matrix/matrix.yeti	Wed Mar 20 22:49:41 2013 +0000
+++ b/yetilab/matrix/matrix.yeti	Thu Mar 21 10:18:18 2013 +0000
@@ -58,17 +58,17 @@
         esac,
     };
 
-newStorage rows cols = 
+newStorage { rows, columns } = 
     if rows < 1 then array []
-    else array (map \(vec.zeros rows) [1..cols])
+    else array (map \(vec.zeros rows) [1..columns])
     fi;
 
-zeroMatrix rows cols = 
-    make (ColM (newStorage rows cols));
+zeroMatrix { rows, columns } = 
+    make (ColM (newStorage { rows, columns }));
 
-generate f rows cols =
-   (m = newStorage rows cols;
-    for [0..cols-1] do col:
+generate f { rows, columns } =
+   (m = newStorage { rows, columns };
+    for [0..columns-1] do col:
         for [0..rows-1] do row:
             m[col][row] := f row col;
         done;
@@ -82,6 +82,7 @@
 width m = m.size.columns;
 height m = m.size.rows;
 
+//!!! should matrices with the same content but different storage order be equal?
 equal m1 m2 =
    (compare d1 d2 =
         all id (map2 vec.equal d1 d2);
@@ -111,12 +112,12 @@
 //again). Is there a word for this?
 flipped m =
     if m.isRowMajor? then id else transposed fi
-       (generate do row col: m.getAt col row done m.size.columns m.size.rows);
+       (generate do row col: m.getAt col row done m.size);
 
 newMatrix type data is RowMajor () | ColumnMajor () -> list?<list?<number>> -> 'a =
    (tagger = case type of RowMajor (): RowM; ColumnMajor (): ColM esac;
     if empty? data or empty? (head data)
-    then zeroMatrix 0 0
+    then zeroMatrix { rows = 0, columns = 0 }
     else make (tagger (array (map vec.vector data)))
     fi);
 
@@ -126,6 +127,28 @@
 newColumnVector data = 
     newMatrix (ColumnMajor ()) [data];
 
+scaled factor m =
+    generate do row col: factor * m.getAt row col done m.size;
+
+sum' m1 m2 =
+    if m1.size != m2.size
+    then failWith "Matrices are not the same size: \(m1.size), \(m2.size)";
+    else
+        generate do row col: m1.getAt row col + m2.getAt row col done m1.size;
+    fi;
+
+product m1 m2 =
+    if m1.size.columns != m2.size.rows
+    then failWith "Matrix dimensions incompatible: \(m1.size), \(m2.size) (\(m1.size.columns != m2.size.rows)";
+    else
+    //!!! super-slow!
+        generate do row col:
+            r = block.list (m1.getRow row);
+            c = block.list (m2.getColumn col);
+            sum (map2 (*) r c);
+        done { rows = m1.size.rows, columns = m2.size.columns }
+    fi;
+
 {
 constMatrix, randomMatrix, zeroMatrix, identityMatrix,
 generate,
@@ -134,6 +157,8 @@
 copyOf,
 transposed,
 flipped,
-newMatrix, newRowVector, newColumnVector
+scaled,
+sum = sum', product,
+newMatrix, newRowVector, newColumnVector,
 }
 
--- a/yetilab/matrix/test/test_matrix.yeti	Wed Mar 20 22:49:41 2013 +0000
+++ b/yetilab/matrix/test/test_matrix.yeti	Thu Mar 21 10:18:18 2013 +0000
@@ -4,89 +4,91 @@
 mat = load yetilab.matrix.matrix;
 block = load yetilab.block.block;
 
+import yeti.lang: FailureException;
+
 { compare } = load yetilab.test.test;
 
 [
 
 "constMatrixEmpty": \(
-    m = mat.constMatrix 2 0 0;
+    m = mat.constMatrix 2 { rows = 0, columns = 0 };
     compare m.size { columns = 0, rows = 0 }
 ),
 
 "constMatrixEmpty2": \(
-    compare (mat.constMatrix 2 0 4).size { columns = 0, rows = 0 } and
-        compare (mat.constMatrix 2 4 0).size { columns = 0, rows = 0 }
+    compare (mat.constMatrix 2 { rows = 0, columns = 4 }).size { columns = 0, rows = 0 } and
+        compare (mat.constMatrix 2 { rows = 4, columns = 0 }).size { columns = 0, rows = 0 }
 ),
 
 "constMatrix": \(
-    m = mat.constMatrix 2 3 4;
+    m = mat.constMatrix 2 { rows = 3, columns = 4 };
     compare m.size { columns = 4, rows = 3 } and
         all id (map do row: compare (block.list (m.getRow row)) [2,2,2,2] done [0..2]) and
         all id (map do col: compare (block.list (m.getColumn col)) [2,2,2] done [0..3])
 ),
 
 "randomMatrixEmpty": \(
-    m = mat.randomMatrix 0 0;
+    m = mat.randomMatrix { rows = 0, columns = 0 };
     compare m.size { columns = 0, rows = 0 }
 ),
 
 "randomMatrix": \(
-    m = mat.randomMatrix 3 4;
+    m = mat.randomMatrix { rows = 3, columns = 4 };
     compare m.size { columns = 4, rows = 3 }
 ),
 
 "zeroMatrixEmpty": \(
-    m = mat.zeroMatrix 0 0;
+    m = mat.zeroMatrix { rows = 0, columns = 0 };
     compare m.size { columns = 0, rows = 0 }
 ),
 
 "zeroMatrix": \(
-    m = mat.zeroMatrix 3 4;
+    m = mat.zeroMatrix { rows = 3, columns = 4 };
     compare m.size { columns = 4, rows = 3 } and
         all id (map do row: compare (block.list (m.getRow row)) [0,0,0,0] done [0..2]) and
         all id (map do col: compare (block.list (m.getColumn col)) [0,0,0] done [0..3])
 ),
 
 "identityMatrixEmpty": \(
-    m = mat.identityMatrix 0 0;
+    m = mat.identityMatrix { rows = 0, columns = 0 };
     compare m.size { columns = 0, rows = 0 }
 ),
 
 "identityMatrix": \(
-    m = mat.identityMatrix 3 4;
+    m = mat.identityMatrix { rows = 3, columns = 4 };
     compare m.size { columns = 4, rows = 3 } and
         all id (map do row: compare (block.list (m.getRow row)) [1,1,1,1] done [0..2]) and
         all id (map do col: compare (block.list (m.getColumn col)) [1,1,1] done [0..3])
 ),
 
 "generateEmpty": \(
-    m = mat.generate do row col: 0 done 0 0;
+    m = mat.generate do row col: 0 done { rows = 0, columns = 0 };
     compare m.size { columns = 0, rows = 0 }
 ),
 
 "generate": \(
-    m = mat.generate do row col: row * 10 + col done 2 3;
+    m = mat.generate do row col: row * 10 + col done { rows = 2, columns = 3 };
     compare (block.list (m.getRow 0)) [0,1,2] and
         compare (block.list (m.getRow 1)) [10,11,12]
 ),
 
 "widthAndHeight": \(
-    m = mat.constMatrix 2 3 4;
+    m = mat.constMatrix 2 { rows = 3, columns = 4 };
     compare m.size { columns = mat.width m, rows = mat.height m }
 ),
 
 "equal": \(
-    m = mat.constMatrix 2 3 4;
+    m = mat.constMatrix 2 { rows = 3, columns = 4 };
     m' = m;
-    p = mat.constMatrix 2 4 3;
-    q = mat.constMatrix 3 3 4;
+    p = mat.constMatrix 2 { rows = 4, columns = 3 };
+    q = mat.constMatrix 3 { rows = 3, columns = 4 };
     mat.equal m m' and mat.equal m m and
        not mat.equal m p and not mat.equal m q and not mat.equal p q
 ),
 
 "getAt": \(
     generator row col = row * 10 + col;
-    m = mat.generate generator 2 3;
+    m = mat.generate generator { rows = 2, columns = 3 };
     all id
        (map do row: all id
            (map do col: m.getAt row col == generator row col done [0..2])
@@ -95,7 +97,7 @@
 
 "setAt": \(
     generator row col = row * 10 + col;
-    m = mat.generate generator 2 3;
+    m = mat.generate generator { rows = 2, columns = 3 };
     m.setAt 1 2 16;
     compare (m.getAt 1 2) 16 and
         compare (m.getAt 1 1) 11 and
@@ -103,13 +105,13 @@
 ),
 
 "copyOfEqual": \(
-    m = mat.constMatrix 2 3 4;
+    m = mat.constMatrix 2 { rows = 3, columns = 4 };
     m'' = mat.copyOf m;
     mat.equal m m''
 ),
 
 "copyOfAlias": \(
-    m = mat.constMatrix 2 3 4;
+    m = mat.constMatrix 2 { rows = 3, columns = 4 };
     m' = m;
     m'' = mat.copyOf m;
     m.setAt 0 0 6;
@@ -117,18 +119,18 @@
 ),
 
 "transposedEmpty": \(
-    compare (mat.transposed (mat.constMatrix 2 0 0)).size { columns = 0, rows = 0 } and
-        compare (mat.transposed (mat.constMatrix 2 0 4)).size { columns = 0, rows = 0 } and
-        compare (mat.transposed (mat.constMatrix 2 4 0)).size { columns = 0, rows = 0 }
+    compare (mat.transposed (mat.constMatrix 2 { rows = 0, columns = 0 })).size { columns = 0, rows = 0 } and
+        compare (mat.transposed (mat.constMatrix 2 { rows = 0, columns = 4 })).size { columns = 0, rows = 0 } and
+        compare (mat.transposed (mat.constMatrix 2 { rows = 4, columns = 0 })).size { columns = 0, rows = 0 }
 ),
 
 "transposedSize": \(
-    compare (mat.transposed (mat.constMatrix 2 3 4)).size { columns = 3, rows = 4 }
+    compare (mat.transposed (mat.constMatrix 2 { rows = 3, columns = 4 })).size { columns = 3, rows = 4 }
 ),
 
 "transposed": \(
     generator row col = row * 10 + col;
-    m = mat.generate generator 2 3;
+    m = mat.generate generator { rows = 2, columns = 3 };
     m' = mat.transposed m;
     all id
        (map do row: all id
@@ -137,5 +139,41 @@
             done [0..1])
 ),
 
+"scaled": \(
+    mat.equal
+       (mat.scaled 0.5 (mat.constMatrix 2 { rows = 3, columns = 4 }))
+       (mat.constMatrix 1 { rows = 3, columns = 4 }) and
+       mat.equal
+          (mat.scaled 0.5 (mat.constMatrix (-3) { rows = 3, columns = 4 }))
+          (mat.constMatrix (-1.5) { rows = 3, columns = 4 }) and
+       mat.equal
+          (mat.scaled 0.5 (mat.constMatrix 2 { rows = 0, columns = 2 }))
+          (mat.constMatrix 5 { rows = 0, columns = 0 })
+),
+
+"sum": \(
+    mat.equal
+       (mat.sum (mat.constMatrix 2 { rows = 3, columns = 4 })
+                (mat.constMatrix 1 { rows = 3, columns = 4 }))
+       (mat.constMatrix 3 { rows = 3, columns = 4 })
+),
+
+"sumFail": \(
+    try 
+        \() (mat.sum (mat.constMatrix 2 { rows = 3, columns = 4 })
+                     (mat.constMatrix 1 { rows = 3, columns = 5 }));
+        false;
+    catch FailureException e:
+        true
+    yrt
+),
+
+"product": \(
+    mat.equal
+       (mat.product (mat.constMatrix 2 { rows = 4, columns = 2 })
+                    (mat.constMatrix 3 { rows = 2, columns = 3 }))
+       (mat.constMatrix 12 { rows = 4, columns = 3 })
+),
+
 ] is hash<string, () -> boolean>;