annotate yetilab/matrix/matrix.yeti @ 159:a9d58d9c71ca

Fix matrix.equal (did not check matrix size, could return true if one was a subset of the other) and add test for it; fix resizedTo test
author Chris Cannam
date Wed, 01 May 2013 12:32:08 +0100
parents b6db07468ed1
children 38938ca5db0c
rev   line source
Chris@5 1
Chris@94 2 module yetilab.matrix.matrix;
Chris@94 3
Chris@94 4 // A matrix is an array of fvectors (i.e. primitive double[]s).
Chris@94 5
Chris@101 6 // A matrix can be stored in either column-major (the default) or
Chris@101 7 // row-major format. Storage order is an efficiency concern only:
Chris@101 8 // every API function operating on matrix objects will return the same
Chris@101 9 // result regardless of storage order. (The transpose function just
Chris@101 10 // switches the row/column order without moving the elements.)
Chris@18 11
Chris@93 12 vec = load yetilab.block.fvector;
Chris@96 13 block = load yetilab.block.block;
Chris@99 14 bf = load yetilab.block.blockfuncs;
Chris@9 15
// Wrap matrix storage d (tagged RowM or ColM, each holding an array of
// fvectors) in a matrix object. Every accessor dispatches on the storage
// tag, so callers see the same results whichever order is in use.
make d = {
    // The tagged storage this object wraps.
    get data () = d,

    // Dimensions as { rows, columns }. When there are no stored vectors
    // at all, the minor dimension is reported as 0 as well.
    get size () =
        case d of
        ColM colVecs:
            ncols = length colVecs;
            {
                rows = if ncols == 0 then 0 else vec.length colVecs[0] fi,
                columns = ncols,
            };
        RowM rowVecs:
            nrows = length rowVecs;
            {
                rows = nrows,
                columns = if nrows == 0 then 0 else vec.length rowVecs[0] fi,
            };
        esac,

    // Column j as a block: taken from the stored column vector when
    // column-major, assembled element by element when row-major.
    getColumn j =
        case d of
        ColM cols: block.block cols[j];
        RowM rows: block.fromList (map do r: getAt r j done [0..length rows - 1]);
        esac,

    // Row i as a block; the mirror image of getColumn.
    getRow i =
        case d of
        RowM rows: block.block rows[i];
        ColM cols: block.fromList (map do c: getAt i c done [0..length cols - 1]);
        esac,

    // The element at (row, col).
    getAt row col =
        case d of
        ColM cols: v = cols[col]; (v is ~double[])[row];
        RowM rows: v = rows[row]; (v is ~double[])[col];
        esac,

    // Overwrite the element at (row, col) in place with n.
    //!!! dangerous, could modify copies -- should it be allowed?
    setAt row col n =
        case d of
        ColM cols: v = cols[col]; (v is ~double[])[row] := n;
        RowM rows: v = rows[row]; (v is ~double[])[col] := n;
        esac,

    // True when the storage is row-major, false when column-major.
    get isRowMajor? () =
        case d of
        ColM _: false;
        RowM _: true;
        esac,
};
Chris@94 59
// Allocate column-major storage for a matrix of the given size: an array
// of `columns` vectors, each created by `vec.zeros rows`. When rows < 1
// the result is an empty array regardless of `columns`.
newStorage { rows, columns } =
    if rows >= 1 then
        array (map do _: vec.zeros rows done [1..columns])
    else
        array []
    fi;
Chris@94 64
// A matrix of the given size backed by fresh column-major storage from
// newStorage (i.e. vectors built with vec.zeros).
zeroMatrix { rows, columns } =
    (storage = newStorage { rows, columns };
     make (ColM storage));
Chris@5 67
// Build a column-major matrix of the given size whose element at
// (row, col) is `f row col`. f is called column by column, and from the
// first row to the last within each column.
generate f { rows, columns } =
    (storage = newStorage { rows, columns };
     colIndices = [0..columns-1];
     rowIndices = [0..rows-1];
     for colIndices do c:
         // Index the storage only inside the inner loop: when rows < 1
         // the storage is empty and must never be touched.
         for rowIndices do r:
             storage[c][r] := f r c;
         done;
     done;
     make (ColM storage));
Chris@5 76
// A matrix of the given size with every element equal to n.
constMatrix n = generate do row col: n done;

// A matrix of the given size filled with values from Math#random().
randomMatrix = generate do row col: Math#random() done;

// An identity matrix of the given size: 1 where row == col (the main
// diagonal, starting at the top-left corner), 0 everywhere else.
// (Previously this was `constMatrix 1`, which filled every element with 1.)
identityMatrix size =
    generate do row col: if row == col then 1 else 0 fi done size;

// A matrix with no rows and no columns.
zeroSizeMatrix () = zeroMatrix { rows = 0, columns = 0 };
Chris@5 81
// Number of columns in m.
width m = (dims = m.size; dims.columns);

// Number of rows in m.
height m = (dims = m.size; dims.rows);
Chris@6 84
// The transpose of m. No elements are moved: the same storage is
// re-tagged with the opposite order, so the result shares its vectors
// with m.
transposed m =
    (retagged =
        case m.data of
        ColM d: RowM d;
        RowM d: ColM d;
        esac;
     make retagged);
Chris@100 91
// A matrix with the same contents as m but the opposite storage order
// (row-major becomes column-major and vice versa). Unlike transposed,
// this builds new storage via generate, copying every element.
flipped m =
    (sz = m.size;
     if not m.isRowMajor? then
         // Build the transpose in column-major form, then re-tag it as
         // row-major; the net effect is m's contents in row-major order.
         transposed
            (generate do r c: m.getAt c r done
                { rows = sz.columns, columns = sz.rows })
     else
         generate do r c: m.getAt r c done sz
     fi);
Chris@100 100
// True when m1 and m2 have the same size and their stored vectors
// compare equal with vec.equal. Matrices with different storage order
// but the same contents are equal, but comparing them is slow: m1 is
// first copied into the other order with flipped.
equal m1 m2 =
    if m1.size != m2.size then
        false
    elif m1.isRowMajor? == m2.isRowMajor? then
        (vectorsEqual a b = all id (map2 vec.equal a b);
         case m1.data of
         ColM c1: case m2.data of ColM c2: vectorsEqual c1 c2; _: false; esac;
         RowM r1: case m2.data of RowM r2: vectorsEqual r1 r2; _: false; esac;
         esac)
    else
        equal (flipped m1) m2
    fi;
Chris@97 113
// A copy of m with the same storage order, in which every stored vector
// has been passed through vec.copyOf.
copyOf m =
    (dup vs = array (map vec.copyOf vs);
     case m.data of
     ColM d: make (ColM (dup d));
     RowM d: make (RowM (dup d));
     esac);
Chris@6 121
// Build a matrix from a list of lists of numbers. With RowMajor () each
// inner list becomes one row; with ColumnMajor () each inner list becomes
// one column. Empty input, or an empty first inner list, yields a 0x0
// matrix. Fails if the inner lists are not all the same length, since
// ragged data cannot form a rectangular matrix (previously it was
// accepted silently and produced a malformed matrix).
newMatrix type data is RowMajor () | ColumnMajor () -> list?<list?<number>> -> 'a =
    (tagger = case type of RowMajor (): RowM; ColumnMajor (): ColM esac;
     if empty? data or empty? (head data) then
         zeroMatrix { rows = 0, columns = 0 }
     else
         n = length (head data);
         consistent = all do v: length v == n done data;
         if not consistent then
             failWith "Matrix data vectors are not all the same length"
         fi;
         make (tagger (array (map vec.vector data)))
     fi);
Chris@96 128
// A single-row, row-major matrix whose row holds the given numbers
// (0x0 if the list is empty).
newRowVector data = newMatrix (RowMajor ()) [data];

// A single-column, column-major matrix whose column holds the given
// numbers (0x0 if the list is empty).
newColumnVector data = newMatrix (ColumnMajor ()) [data];
Chris@8 134
// A new (column-major) matrix of m's size with every element multiplied
// by factor.
scaled factor m =
    (valueAt = m.getAt;
     generate do r c: factor * valueAt r c done m.size);
Chris@98 137
// m resized to newsize. Elements inside both the old and the new bounds
// are kept, and new positions are filled with 0. If the size is
// unchanged, m itself is returned (not a copy).
resizedTo newsize m =
    (old = m.size;
     inside r c = r < old.rows and c < old.columns;
     if newsize != old then
         generate do r c:
             if inside r c then m.getAt r c else 0 fi
         done newsize
     else
         m
     fi);
Chris@158 147
// Element-wise sum of two matrices of the same size, as a new
// column-major matrix. Fails if the sizes differ.
sum' m1 m2 =
    (sz = m1.size;
     if sz == m2.size then
         generate do r c: m1.getAt r c + m2.getAt r c done sz
     else
         failWith "Matrices are not the same size: \(m1.size), \(m2.size)"
     fi);
Chris@98 154
// Matrix product m1 * m2, as a new column-major matrix of size
// { rows = rows of m1, columns = columns of m2 }. Fails if the number of
// columns of m1 differs from the number of rows of m2.
product m1 m2 =
    (s1 = m1.size;
     s2 = m2.size;
     if s1.columns != s2.rows then
         // Report the two mismatched dimensions themselves (the old
         // message interpolated an always-true boolean and left a
         // parenthesis unbalanced).
         failWith "Matrix dimensions incompatible: \(s1), \(s2) (\(s1.columns) != \(s2.rows))"
     else
         // Extract each row of m1 and each column of m2 once, instead of
         // rebuilding them for every output element.
         rows1 = array (map m1.getRow [0..s1.rows - 1]);
         cols2 = array (map m2.getColumn [0..s2.columns - 1]);
         generate do row col:
             bf.sum (bf.multiply rows1[row] cols2[col])
         done { rows = s1.rows, columns = s2.columns }
     fi);
Chris@98 163
{
    // Construction
    constMatrix, randomMatrix, zeroMatrix, identityMatrix, zeroSizeMatrix,
    generate,
    newMatrix, newRowVector, newColumnVector,

    // Inspection and comparison
    width, height,
    equal,

    // Copying and storage order
    copyOf, transposed, flipped,

    // Arithmetic and reshaping
    scaled, resizedTo,
    sum = sum',
    product,
}
Chris@5 177