changeset 548:712999a0ad66

Make matrix sum also take a list (and add vector random)
author Chris Cannam
date Mon, 24 Mar 2014 11:50:05 +0000
parents 1200e1029ecb
children 0905242b46cb
files src/may/matrix.yeti src/may/matrix/complex.yeti src/may/matrix/test/speedtest.yeti src/may/matrix/test/test_matrix.yeti src/may/stream/framer.yeti src/may/stream/manipulate.yeti src/may/vector.yeti
diffstat 7 files changed, 72 insertions(+), 48 deletions(-) [+]
line wrap: on
line diff
--- a/src/may/matrix.yeti	Mon Mar 24 11:23:28 2014 +0000
+++ b/src/may/matrix.yeti	Mon Mar 24 11:50:05 2014 +0000
@@ -504,16 +504,22 @@
             done (keys h));
     newSparseMatching m1 entries);
 
-sum' m1 m2 =
-    if (size m1) != (size m2)
-    then failWith "Matrices are not the same size: \(size m1), \(size m2)";
-    elif isSparse? m1 and isSparse? m2 then
-        sparseSumOrDifference (+) m1 m2;
-    else
-        add2 v1 v2 = vec.add [v1,v2];
-        denseLinearOp add2 m1 m2;
-    fi;
-
+sum' mm =
+    case mm of
+    m1::m2::rest:
+        sum' 
+           (if (size m1) != (size m2)
+            then failWith "Matrices are not the same size: \(size m1), \(size m2)";
+            elif isSparse? m1 and isSparse? m2 then
+                sparseSumOrDifference (+) m1 m2;
+            else
+                add2 v1 v2 = vec.add [v1,v2];
+                denseLinearOp add2 m1 m2;
+            fi :: rest);
+    [m1]: m1;
+    _: failWith "Empty argument list";
+    esac;
+    
 difference m1 m2 =
     if (size m1) != (size m2)
     then failWith "Matrices are not the same size: \(size m1), \(size m2)";
@@ -1034,7 +1040,7 @@
     maxValue is matrix_t -> number,
     asRows is matrix_t -> list<vec.vector_t>, 
     asColumns is matrix_t -> list<vec.vector_t>,
-    sum is matrix_t -> matrix_t -> matrix_t,
+    sum is list?<matrix_t> -> matrix_t,
     difference is matrix_t -> matrix_t -> matrix_t,
     abs is matrix_t -> matrix_t,
     negative is matrix_t -> matrix_t,
--- a/src/may/matrix/complex.yeti	Mon Mar 24 11:23:28 2014 +0000
+++ b/src/may/matrix/complex.yeti	Mon Mar 24 11:50:05 2014 +0000
@@ -204,7 +204,7 @@
     case p1 of
     Some m1:
         case p2 of
-        Some m2: Some (mat.sum m1 m2);
+        Some m2: Some (mat.sum [m1, m2]);
         none: Some m1;
         esac;
     none:
@@ -239,16 +239,22 @@
         none;
     esac;
 
-sum c1 c2 =
-   (a = c1.real;
-    b = c1.imaginary;
-    c = c2.real;
-    d = c2.imaginary;
-    {
-        size = c1.size,
-        real = addParts a c,
-        imaginary = addParts b d,
-    });
+sum cc =
+    case cc of
+    c1::c2::rest:
+        sum
+           (a = c1.real;
+            b = c1.imaginary;
+            c = c2.real;
+            d = c2.imaginary;
+            {
+                size = c1.size,
+                real = addParts a c,
+                imaginary = addParts b d,
+            } :: rest);
+    [c1]: c1;
+    _: failWith "Empty argument list";
+    esac;
 
 product c1 c2 =
    (a = c1.real;
@@ -374,7 +380,7 @@
     toSparse is complexmatrix_t -> complexmatrix_t,
     toDense is complexmatrix_t -> complexmatrix_t,
     scaled is number -> complexmatrix_t -> complexmatrix_t,
-    sum is complexmatrix_t -> complexmatrix_t -> complexmatrix_t,
+    sum is list?<complexmatrix_t> -> complexmatrix_t,
     difference is complexmatrix_t -> complexmatrix_t -> complexmatrix_t,
     abs is complexmatrix_t -> mat.matrix_t,
     product is complexmatrix_t -> complexmatrix_t -> complexmatrix_t,
--- a/src/may/matrix/test/speedtest.yeti	Mon Mar 24 11:23:28 2014 +0000
+++ b/src/may/matrix/test/speedtest.yeti	Mon Mar 24 11:50:05 2014 +0000
@@ -127,30 +127,30 @@
 println "";
 
 print "CMD + CMD... ";
-reportOn (time \(mat.sum cmd cmd));
+reportOn (time \(mat.sum [cmd, cmd]));
 
 print "CMD + RMD... ";
-reportOn (time \(mat.sum cmd rmd));
+reportOn (time \(mat.sum [cmd, rmd]));
 
 print "RMD + CMD... ";
-reportOn (time \(mat.sum rmd cmd));
+reportOn (time \(mat.sum [rmd, cmd]));
 
 print "RMD + RMD... ";
-reportOn (time \(mat.sum rmd rmd));
+reportOn (time \(mat.sum [rmd, rmd]));
 
 println "";
 
 print "CMS + CMS... ";
-reportOn (time \(mat.sum cms cms));
+reportOn (time \(mat.sum [cms, cms]));
 
 print "CMS + RMS... ";
-reportOn (time \(mat.sum cms rms));
+reportOn (time \(mat.sum [cms, rms]));
 
 print "RMS + CMS... ";
-reportOn (time \(mat.sum rms cms));
+reportOn (time \(mat.sum [rms, cms]));
 
 print "RMS + RMS... ";
-reportOn (time \(mat.sum rms rms));
+reportOn (time \(mat.sum [rms, rms]));
 
 println "";
 
@@ -204,15 +204,15 @@
 println "";
 
 print "CMS + CMS... ";
-reportOn (time \(mat.sum cms cms));
+reportOn (time \(mat.sum [cms, cms]));
 
 print "CMS + RMS... ";
-reportOn (time \(mat.sum cms rms));
+reportOn (time \(mat.sum [cms, rms]));
 
 print "RMS + CMS... ";
-reportOn (time \(mat.sum rms cms));
+reportOn (time \(mat.sum [rms, cms]));
 
 print "RMS + RMS... ";
-reportOn (time \(mat.sum rms rms));
+reportOn (time \(mat.sum [rms, rms]));
 
 ();
--- a/src/may/matrix/test/test_matrix.yeti	Mon Mar 24 11:23:28 2014 +0000
+++ b/src/may/matrix/test/test_matrix.yeti	Mon Mar 24 11:50:05 2014 +0000
@@ -229,15 +229,20 @@
 
 "sum-\(name)": \(
     compareMatrices
-       (mat.sum (constMatrix 2 { rows = 3, columns = 4 })
-                (constMatrix 1 { rows = 3, columns = 4 }))
-       (constMatrix 3 { rows = 3, columns = 4 })
+       (mat.sum [constMatrix 2 { rows = 3, columns = 4 },
+                 constMatrix 1 { rows = 3, columns = 4 }])
+       (constMatrix 3 { rows = 3, columns = 4 }) and
+    compareMatrices
+       (mat.sum [constMatrix 2 { rows = 3, columns = 4 },
+                 constMatrix 1 { rows = 3, columns = 4 },
+                 fromRows [[1,2,3,4],[5,6,7,8],[9,10,11,12]]])
+       (fromRows [[4,5,6,7],[8,9,10,11],[12,13,14,15]])
 ),
 
 "sumFail-\(name)": \(
     try 
-      \() (mat.sum (constMatrix 2 { rows = 3, columns = 4 })
-                   (constMatrix 1 { rows = 3, columns = 5 }));
+      \() (mat.sum [constMatrix 2 { rows = 3, columns = 4 },
+                    constMatrix 1 { rows = 3, columns = 5 }]);
         false;
     catch FailureException e:
         true
@@ -257,11 +262,11 @@
             { i = 1, j = 0, v = 5 },
             { i = 1, j = 1, v = -4 }, // NB this means [1,1] -> 0, sparse zero
         ]);
-    tot = mat.sum s t;
+    tot = mat.sum [s, t];
     mat.isSparse? tot and
-        compareMatrices tot (mat.sum (mat.toDense s) t) and
-        compareMatrices tot (mat.sum (mat.toDense s) (mat.toDense t)) and
-        compareMatrices tot (mat.sum s (mat.toDense t)) and
+        compareMatrices tot (mat.sum [mat.toDense s, t]) and
+        compareMatrices tot (mat.sum [mat.toDense s, mat.toDense t]) and
+        compareMatrices tot (mat.sum [s, mat.toDense t]) and
         compareMatrices tot 
            (mat.newSparseMatrix { rows = 2, columns = 3 } (Rows [
                { i = 0, j = 0, v = 1 },
--- a/src/may/stream/framer.yeti	Mon Mar 24 11:23:28 2014 +0000
+++ b/src/may/stream/framer.yeti	Mon Mar 24 11:50:05 2014 +0000
@@ -163,9 +163,10 @@
         first::rest:
            (w = mat.width pending;
             pre = mat.columnSlice pending 0 (w - overlap);
-            added = mat.sum first
-               (mat.resizedTo (mat.size first)
-               (mat.columnSlice pending (w - overlap) w));
+            added = mat.sum
+               [first,
+                (mat.resizedTo (mat.size first)
+                    (mat.columnSlice pending (w - overlap) w))];
             ola rest added (pre::acc));
          _:
             reverse (pending::acc);
--- a/src/may/stream/manipulate.yeti	Mon Mar 24 11:23:28 2014 +0000
+++ b/src/may/stream/manipulate.yeti	Mon Mar 24 11:50:05 2014 +0000
@@ -152,7 +152,7 @@
         if sz.columns == 0 then
             mat.zeroMatrix sz
         else
-            mat.sum (mat.resizedTo sz m1) (mat.resizedTo sz m2);
+            mat.sum [mat.resizedTo sz m1, mat.resizedTo sz m2];
         fi);
     channels = head (sortBy (>) (map (.channels) streams));
     {
--- a/src/may/vector.yeti	Mon Mar 24 11:23:28 2014 +0000
+++ b/src/may/vector.yeti	Mon Mar 24 11:50:05 2014 +0000
@@ -33,6 +33,10 @@
 /// Return a vector of length n, containing all ones.
 ones n = consts 1.0 n;
 
+/// Return a vector of length n, containing random values.
+randoms n is number -> ~double[] =
+   (map \(Math#random()) [1..n]) as ~double[];
+
 /// Return a vector of the values in the given list.
 fromList l is list?<number> -> ~double[] =
     l as ~double[];
@@ -302,6 +306,7 @@
     zeros,
     consts,
     ones,
+    randoms,
     vector v = v,
     primitive,
     floats,
@@ -343,6 +348,7 @@
     zeros is number -> vector_t,
     consts is number -> number -> vector_t,
     ones is number -> vector_t,
+    randoms is number -> vector_t,
     vector is ~double[] -> vector_t,
     primitive is vector_t -> ~double[],
     floats is vector_t -> ~float[],