changeset 545:01863795221c

Entry-wise divide
author Chris Cannam
date Fri, 21 Mar 2014 17:32:48 +0000
parents 8112db99ab50
children 17d5a8986f6f
files src/may/bits/VectorBits.java src/may/matrix.yeti src/may/matrix/test/test_matrix.yeti src/may/vector.yeti src/may/vector/test/test_vector.yeti
diffstat 5 files changed, 63 insertions(+), 6 deletions(-) [+]
line wrap: on
line diff
--- a/src/may/bits/VectorBits.java	Fri Mar 21 16:37:28 2014 +0000
+++ b/src/may/bits/VectorBits.java	Fri Mar 21 17:32:48 2014 +0000
@@ -24,6 +24,18 @@
 	return out;
     }
 
+    public static double[] divide(double[] v1, double[] v2) {
+	int len = v1.length;
+	if (v2.length < len) {
+	    len = v2.length;
+	}
+	double[] out = new double[len];
+	for (int i = 0; i < len; ++i) {
+	    out[i] = v1[i] / v2[i];
+	}
+	return out;
+    }
+
     public static void addTo(double[] out, double[] in, int len) {
 	for (int i = 0; i < len; ++i) {
 	    out[i] += in[i];
--- a/src/may/matrix.yeti	Fri Mar 21 16:37:28 2014 +0000
+++ b/src/may/matrix.yeti	Fri Mar 21 17:32:48 2014 +0000
@@ -676,7 +676,7 @@
         fi;
     fi;
 
-entryWiseProduct m1 m2 = // or element-wise, or Hadamard product
+entryWiseProduct m1 m2 =
     if (size m1) != (size m2)
     then failWith "Matrices are not the same size: \(size m1), \(size m2)";
     else 
@@ -696,6 +696,26 @@
         fi
     fi;
 
+entryWiseDivide m1 m2 =
+    if (size m1) != (size m2)
+    then failWith "Matrices are not the same size: \(size m1), \(size m2)";
+    else 
+        if isSparse? m1 then
+            newSparse (size m1)
+               ((taggerForTypeOf m1)
+                   (map do { i, j, v }: { i, j, v = v / (at' m2 i j) } done
+                       (enumerateSparse m1)))
+        // For m2 to be sparse makes no sense (divide by zero all over
+        // the shop).
+        else
+            if isRowMajor? m1 then
+                fromRows (array (map2 vec.divide (asRows m1) (asRows m2)));
+            else
+                fromColumns (array (map2 vec.divide (asColumns m1) (asColumns m2)));
+            fi
+        fi
+    fi;
+
 concatAgainstGrain tagger getter counter mm =
    (n = counter (size (head mm));
     tagger (array
@@ -951,6 +971,7 @@
     any = any',
     product,
     entryWiseProduct,
+    entryWiseDivide,
     concatHorizontal,
     concatVertical,
     rowSlice,
@@ -1011,6 +1032,7 @@
     any is (number -> boolean) -> matrix_t -> boolean,
     product is matrix_t -> matrix_t -> matrix_t,
     entryWiseProduct is matrix_t -> matrix_t -> matrix_t,
+    entryWiseDivide is matrix_t -> matrix_t -> matrix_t,
     concatHorizontal is list<matrix_t> -> matrix_t,
     concatVertical is list<matrix_t> -> matrix_t,
     rowSlice is matrix_t -> number -> number -> matrix_t, 
--- a/src/may/matrix/test/test_matrix.yeti	Fri Mar 21 16:37:28 2014 +0000
+++ b/src/may/matrix/test/test_matrix.yeti	Fri Mar 21 17:32:48 2014 +0000
@@ -345,6 +345,19 @@
        (fromRows [[6,14,24],[0,5,0]])
 ),
 
+"entryWiseDivide-\(name)": \(
+    compareMatrices
+       (mat.entryWiseDivide
+           (fromRows [[1,2,3],[4,5,0]])
+           (fromRows [[6,7,8],[1,2,3]]))
+       (fromRows [[1/6,2/7,3/8],[4,5/2,0]]) and
+    compareMatrices
+       (mat.entryWiseDivide
+           (fromRows [[1,2,3],[4,5,0]])
+           (fromColumns [[6,1],[7,2],[8,3]]))
+       (fromRows [[1/6,2/7,3/8],[4,5/2,0]]);
+),
+
 "sparseProduct-\(name)": \(
     s = mat.newSparseMatrix { rows = 2, columns = 3 } (Columns [
         { i = 0, j = 0, v = 1 },
--- a/src/may/vector.yeti	Fri Mar 21 16:37:28 2014 +0000
+++ b/src/may/vector.yeti	Fri Mar 21 17:32:48 2014 +0000
@@ -223,6 +223,9 @@
 multiply b1 b2 is ~double[] -> ~double[] -> ~double[] = 
     VectorBits#multiply(b1, b2);
 
+divide b1 b2 is ~double[] -> ~double[] -> ~double[] = 
+    VectorBits#divide(b1, b2);
+
 scaled n v is number -> ~double[] -> ~double[] =
     if n == 1 then v
     else VectorBits#scaled(v, n);
@@ -315,8 +318,9 @@
     add,
     subtract,
     multiply,
+    divide,
+    scaled,
     divideBy,
-    scaled,
     abs = abs',
     negative,
     sqr,
@@ -355,8 +359,9 @@
     add is list?<vector_t> -> vector_t,
     subtract is vector_t -> vector_t -> vector_t,
     multiply is vector_t -> vector_t -> vector_t, 
+    divide is vector_t -> vector_t -> vector_t, 
+    scaled is number -> vector_t -> vector_t,
     divideBy is number -> vector_t -> vector_t, 
-    scaled is number -> vector_t -> vector_t,
     abs is vector_t -> vector_t,
     negative is vector_t -> vector_t,
     sqr is vector_t -> vector_t,
--- a/src/may/vector/test/test_vector.yeti	Fri Mar 21 16:37:28 2014 +0000
+++ b/src/may/vector/test/test_vector.yeti	Fri Mar 21 17:32:48 2014 +0000
@@ -183,9 +183,9 @@
         compare (vec.list (vec.multiply (vec.consts (-3) 4) (vec.fromList [1,2,3]))) [-3,-6,-9]
 ),
 
-"divideBy": \(
-    compare (vec.list (vec.divideBy 5 (vec.ones 0))) [] and
-        compare (vec.list (vec.divideBy 5 (vec.fromList [1,2,-3]))) [0.2,0.4,-0.6]
+"divide": \(
+    compare (vec.list (vec.divide (vec.zeros 0) (vec.ones 5))) [] and
+        compare (vec.list (vec.divide (vec.consts (-3) 4) (vec.fromList [1,2,3]))) [-3,-(3/2),-1]
 ),
 
 "scaled": \(
@@ -193,6 +193,11 @@
         compare (vec.list (vec.scaled 5 (vec.fromList [1,2,-3]))) [5,10,-15]
 ),
 
+"divideBy": \(
+    compare (vec.list (vec.divideBy 5 (vec.ones 0))) [] and
+        compare (vec.list (vec.divideBy 5 (vec.fromList [1,2,-3]))) [0.2,0.4,-0.6]
+),
+
 "abs": \(
     compare (vec.list (vec.abs (vec.ones 0))) [] and
         compare (vec.list (vec.abs (vec.fromList [1,2,-3]))) [1,2,3]