changeset 554:3fdffd2d0649

Make vector arithmetic functions throw if args are of differing lengths -- definitely makes for a higher probability of getting correct code. Also matches the matrix code, which expects args to be of the same size
author Chris Cannam
date Mon, 31 Mar 2014 14:31:56 +0100
parents 25b925cf3c98
children 30799a778cac
files src/may/bits/VectorBits.java src/may/stream/resample.yeti src/may/test.yeti src/may/vector.yeti src/may/vector/test/test_vector.yeti
diffstat 5 files changed, 42 insertions(+), 19 deletions(-) [+]
line wrap: on
line diff
--- a/src/may/bits/VectorBits.java	Mon Mar 31 10:28:42 2014 +0100
+++ b/src/may/bits/VectorBits.java	Mon Mar 31 14:31:56 2014 +0100
@@ -3,6 +3,15 @@
 
 public class VectorBits
 {
+    public static void checkLengths(double[] v1, double[] v2) {
+	if (v1.length != v2.length) {
+	    throw new IllegalArgumentException
+		("Found vector of length " + v2.length +
+		 ", but all so far in this arithmetic operation have had length " +
+		 v1.length);
+	}
+    }
+
     public static double sum(double[] v) {
 	double tot = 0.0;
 	int len = v.length;
@@ -13,6 +22,7 @@
     }
 
     public static void multiplyBy(double[] out, double[] in) {
+	checkLengths(out, in);
 	for (int i = 0; i < in.length && i < out.length; ++i) {
 	    out[i] *= in[i];
 	}
@@ -22,18 +32,21 @@
     }
 
     public static void divideBy(double[] out, double[] in) {
+	checkLengths(out, in);
 	for (int i = 0; i < in.length && i < out.length; ++i) {
 	    out[i] /= in[i];
 	}
     }
 
     public static void addTo(double[] out, double[] in) {
+	checkLengths(out, in);
 	for (int i = 0; i < in.length && i < out.length; ++i) {
 	    out[i] += in[i];
 	}
     }
 
     public static void subtractFrom(double[] out, double[] in) {
+	checkLengths(out, in);
 	for (int i = 0; i < in.length && i < out.length; ++i) {
 	    out[i] -= in[i];
 	}
--- a/src/may/stream/resample.yeti	Mon Mar 31 10:28:42 2014 +0100
+++ b/src/may/stream/resample.yeti	Mon Mar 31 14:31:56 2014 +0100
@@ -297,6 +297,7 @@
                 pd = phaseData[phase];
 
                 for [0..s.channels-1] do ch:
+//                println "lengths \(vec.length (mat.getRow ch buffer)) and \(vec.length pd.filter)";
                     rowdata[ch][i] := vec.sum
                        (vec.multiply [mat.getRow ch buffer, pd.filter]);
                 done;
--- a/src/may/test.yeti	Mon Mar 31 10:28:42 2014 +0100
+++ b/src/may/test.yeti	Mon Mar 31 14:31:56 2014 +0100
@@ -44,6 +44,16 @@
         false;
     fi);
 
+assertException f =
+    try
+        \() (f ());
+        println "** failed to catch expected exception";
+        false;
+    catch Exception _:
+        goodCompares := goodCompares + 1;
+        true;
+    yrt;
+
 time msg f =
    (start = System#currentTimeMillis();
     result = f ();
@@ -62,7 +72,7 @@
                     println "Test \(name) failed";
                     name;
                 fi 
-            catch FailureException e:
+            catch Exception e:
                 println "Test \(name) threw exception: \(e)";
                 name;
             yrt;
@@ -81,7 +91,7 @@
     bad);
 
 {
-    compare, compareUsing, compareMatrices, assert,
+    compare, compareUsing, compareMatrices, assert, assertException,
     time,
     runTests, 
 }
--- a/src/may/vector.yeti	Mon Mar 31 10:28:42 2014 +0100
+++ b/src/may/vector.yeti	Mon Mar 31 14:31:56 2014 +0100
@@ -218,19 +218,19 @@
     _: failWith "Empty argument list";
     esac;
 
-//!!! doc: returned vector is same length as first argument (this has changed, formerly it was length of shortest argument)
+//!!! doc: throws exception if not all inputs are of the same length
 add vv is list?<~double[]> -> ~double[] =
     listOp do out v: VectorBits#addTo(out, v) done vv;
 
-//!!! doc: returned vector is same length as first argument (this has changed, formerly it was length of shortest argument)
+//!!! doc: throws exception if not all inputs are of the same length
 subtract v1 v2 is ~double[] -> ~double[] -> ~double[] =
     listOp do out v: VectorBits#subtractFrom(out, v) done [v1, v2];
 
-//!!! doc: returned vector is same length as first argument (this has changed, formerly it was length of shortest argument). If first arg is longer than others, the spare values will become zero
+//!!! doc: throws exception if not all inputs are of the same length
 multiply vv is list?<~double[]> -> ~double[] =
     listOp do out v: VectorBits#multiplyBy(out, v) done vv;
 
-//!!! doc: returned vector is same length as first argument (this has changed, formerly it was length of shortest argument). If first arg is longer than second, the spare values are left untouched (not divided by zero)
+//!!! doc: throws exception if not all inputs are of the same length
 divide v1 v2 is ~double[] -> ~double[] -> ~double[] = 
     listOp do out v: VectorBits#divideBy(out, v) done [v1, v2];
 
--- a/src/may/vector/test/test_vector.yeti	Mon Mar 31 10:28:42 2014 +0100
+++ b/src/may/vector/test/test_vector.yeti	Mon Mar 31 14:31:56 2014 +0100
@@ -3,7 +3,7 @@
 
 vec = load may.vector;
 
-{ compare } = load may.test;
+{ compare, assertException } = load may.test;
 
 [
 
@@ -167,28 +167,27 @@
 ),
 
 "add": \(
-    compare (vec.list (vec.add [vec.zeros 0, vec.ones 5])) [] and
-        compare (vec.list (vec.add [vec.consts 3 4, vec.fromList [1,2,3] ])) [4,5,6,3] and
-        compare (vec.list (vec.add [vec.consts (-3) 4, vec.fromList [1,2,3] ])) [-2,-1,0,-3] and
-        compare (vec.list (vec.add [vec.consts 3 3, vec.fromList [1,2,3], vec.fromList [6,7,8,9] ])) [10,12,14] 
+    compare (vec.list (vec.add [vec.zeros 0, vec.ones 0])) [] and
+        compare (vec.list (vec.add [vec.consts 3 4, vec.fromList [1,2,3,0] ])) [4,5,6,3] and
+        assertException \(vec.add [vec.consts 3 3, vec.fromList [1,2,3], vec.fromList [6,7,8,9] ])
 ),
 
 "subtract": \(
-    compare (vec.list (vec.subtract (vec.zeros 0) (vec.ones 5))) [] and
-        compare (vec.list (vec.subtract (vec.consts 3 4) (vec.fromList [1,2,3]))) [2,1,0,3] and
-        compare (vec.list (vec.subtract (vec.consts (-3) 4) (vec.fromList [1,2,3]))) [-4,-5,-6,-3]
+    compare (vec.list (vec.subtract (vec.zeros 0) (vec.ones 0))) [] and
+        compare (vec.list (vec.subtract (vec.consts 3 4) (vec.fromList [1,2,3,0]))) [2,1,0,3] and
+        assertException \(vec.subtract (vec.consts (-3) 4) (vec.fromList [1,2,3]))
 ),
 
 "multiply": \(
-    compare (vec.list (vec.multiply [vec.zeros 0, vec.ones 5])) [] and
-        compare (vec.list (vec.multiply [vec.consts (-3) 4, vec.fromList [1,2,3]])) [-3,-6,-9,0] and
-        compare (vec.list (vec.multiply [vec.consts 3 3, vec.fromList [1,2,3], vec.fromList [6,7,8,9]])) [18,42,72]
+    compare (vec.list (vec.multiply [vec.zeros 0, vec.ones 0])) [] and
+        compare (vec.list (vec.multiply [vec.consts (-3) 4, vec.fromList [1,2,3,0]])) [-3,-6,-9,0] and
+        assertException \(vec.multiply [vec.consts 3 3, vec.fromList [1,2,3], vec.fromList [6,7,8,9]])
 ),
 
 "divide": \(
-    compare (vec.list (vec.divide (vec.zeros 0) (vec.ones 5))) [] and
+    compare (vec.list (vec.divide (vec.zeros 0) (vec.ones 0))) [] and
         compare (vec.list (vec.divide (vec.consts (-3) 3) (vec.fromList [1,2,3]))) [-3,-(3/2),-1] and
-        compare (vec.list (vec.divide (vec.consts (-3) 4) (vec.fromList [1,2,3]))) [-3,-(3/2),-1,-3] // values in first vec beyond end of second are *not* divided by zero, just left alone
+        assertException \(vec.divide (vec.consts (-3) 4) (vec.fromList [1,2,3]));
 ),
 
 "scaled": \(