comparison yetilab/matrix/matrix.yeti @ 223:51af10e6cd0d

Merge from matrix_opaque_immutable branch
author Chris Cannam
date Sat, 11 May 2013 16:00:58 +0100
parents 77c6a81c577f
children c00d8f7e2708
comparison
equal deleted inserted replaced
207:cd2caf235e1f 223:51af10e6cd0d
1 1
2 module yetilab.matrix.matrix; 2 module yetilab.matrix.matrix;
3 3
4 // A matrix is an array of fvectors (i.e. primitive double[]s). 4 // A matrix is an array of vectors.
5 5
6 // A matrix can be stored in either column-major (the default) or 6 // A matrix can be stored in either column-major (the default) or
7 // row-major format. Storage order is an efficiency concern only: 7 // row-major format. Storage order is an efficiency concern only:
8 // every API function operating on matrix objects will return the same 8 // every API function operating on matrix objects will return the same
9 // result regardless of storage order. (The transpose function just 9 // result regardless of storage order. (The transpose function just
10 // switches the row/column order without moving the elements.) 10 // switches the row/column order without moving the elements.)
11 11
12 vec = load yetilab.block.fvector; 12 //!!! check that we are not unnecessarily copying in the transform functions
13 block = load yetilab.block.block; 13
14 bf = load yetilab.block.blockfuncs; 14 vec = load yetilab.vector.vector;
15 15 bf = load yetilab.vector.blockfuncs;
16 load yetilab.block.blocktype; 16
17 load yetilab.vector.vectortype;
17 load yetilab.matrix.matrixtype; 18 load yetilab.matrix.matrixtype;
18 19
19 make d = { 20 size m =
20 get data () = d, 21 case m of
21 get size () = 22 RowM r:
22 case d of 23 major = length r;
23 RowM r: 24 {
24 major = length r; 25 rows = major,
25 { 26 columns = if major > 0 then vec.length r[0] else 0 fi,
26 rows = major, 27 };
27 columns = if major > 0 then vec.length r[0] else 0 fi, 28 ColM c:
28 }; 29 major = length c;
29 ColM c: 30 {
30 major = length c; 31 rows = if major > 0 then vec.length c[0] else 0 fi,
31 { 32 columns = major,
32 rows = if major > 0 then vec.length c[0] else 0 fi, 33 };
33 columns = major, 34 esac;
34 }; 35
35 esac, 36 width m = (size m).columns;
36 getColumn j = 37 height m = (size m).rows;
37 case d of 38
38 RowM rows: block.fromList (map do i: getAt i j done [0..length rows-1]); 39 getAt row col m =
39 ColM cols: block.block cols[j]; 40 case m of
40 esac, 41 RowM rows: r = rows[row]; vec.at col r;
41 getRow i = 42 ColM cols: c = cols[col]; vec.at row c;
42 case d of 43 esac;
43 RowM rows: block.block rows[i]; 44
44 ColM cols: block.fromList (map do j: getAt i j done [0..length cols-1]); 45 getColumn j m =
45 esac, 46 case m of
46 getAt row col = 47 RowM rows: vec.fromList (map do i: getAt i j m done [0..length rows-1]);
47 case d of 48 ColM cols: cols[j];
48 RowM rows: r = rows[row]; (r is ~double[])[col]; 49 esac;
49 ColM cols: c = cols[col]; (c is ~double[])[row]; 50
50 esac, 51 getRow i m =
51 setAt row col n = //!!! dangerous, could modify copies -- should it be allowed? 52 case m of
52 case d of 53 RowM rows: rows[i];
53 RowM rows: r = rows[row]; (r is ~double[])[col] := n; 54 ColM cols: vec.fromList (map do j: getAt i j m done [0..length cols-1]);
54 ColM cols: c = cols[col]; (c is ~double[])[row] := n; 55 esac;
55 esac, 56
56 get isRowMajor? () = 57 /*
57 case d of 58 setAt row col n m = //!!! dangerous, could modify copies -- should it be allowed?
58 RowM _: true; 59 case m of
59 ColM _: false; 60 RowM rows: r = rows[row]; (vec.data r)[col] := n;
60 esac, 61 ColM cols: c = cols[col]; (vec.data c)[row] := n;
61 }; 62 esac;
63 */
64
65 isRowMajor? m =
66 case m of
67 RowM _: true;
68 ColM _: false;
69 esac;
62 70
63 newColMajorStorage { rows, columns } = 71 newColMajorStorage { rows, columns } =
64 if rows < 1 then array [] 72 if rows < 1 then array []
65 else array (map \(vec.zeros rows) [1..columns]) 73 else array (map \(vec.zeros rows) [1..columns])
66 fi; 74 fi;
67 75
68 zeroMatrix { rows, columns } = 76 zeroMatrix { rows, columns } =
69 make (ColM (newColMajorStorage { rows, columns })); 77 ColM (newColMajorStorage { rows, columns });
70 78
71 zeroMatrixWithTypeOf m { rows, columns } = 79 zeroMatrixWithTypeOf m { rows, columns } =
72 if m.isRowMajor? then 80 if isRowMajor? m then
73 make (RowM (newColMajorStorage { rows = columns, columns = rows })); 81 RowM (newColMajorStorage { rows = columns, columns = rows });
74 else 82 else
75 make (ColM (newColMajorStorage { rows, columns })); 83 ColM (newColMajorStorage { rows, columns });
76 fi; 84 fi;
85
86 zeroSizeMatrix () = zeroMatrix { rows = 0, columns = 0 };
77 87
78 generate f { rows, columns } = 88 generate f { rows, columns } =
79 (m = newColMajorStorage { rows, columns }; 89 if rows < 1 or columns < 1 then zeroSizeMatrix ()
80 for [0..columns-1] do col: 90 else
81 for [0..rows-1] do row: 91 m = array (map \(new double[rows]) [1..columns]);
82 m[col][row] := f row col; 92 for [0..columns-1] do col:
93 for [0..rows-1] do row:
94 m[col][row] := f row col;
95 done;
83 done; 96 done;
84 done; 97 ColM (array (map vec.vector m))
85 make (ColM m)); 98 fi;
86 99
87 constMatrix n = generate do row col: n done; 100 constMatrix n = generate do row col: n done;
88 randomMatrix = generate do row col: Math#random() done; 101 randomMatrix = generate do row col: Math#random() done;
89 identityMatrix = constMatrix 1; 102 identityMatrix = constMatrix 1;
90 zeroSizeMatrix () = zeroMatrix { rows = 0, columns = 0 };
91
92 width m = m.size.columns;
93 height m = m.size.rows;
94 103
95 transposed m = 104 transposed m =
96 make 105 case m of
97 (case m.data of 106 RowM d: ColM d;
98 RowM d: ColM d; 107 ColM d: RowM d;
99 ColM d: RowM d; 108 esac;
100 esac);
101 109
102 flipped m = 110 flipped m =
103 if m.isRowMajor? then 111 if isRowMajor? m then
104 generate do row col: m.getAt row col done m.size; 112 generate do row col: getAt row col m done (size m);
105 else 113 else
106 transposed 114 transposed
107 (generate do row col: m.getAt col row done 115 (generate do row col: getAt col row m done
108 { rows = m.size.columns, columns = m.size.rows }); 116 { rows = (width m), columns = (height m) });
109 fi; 117 fi;
110 118
111 toRowMajor m = 119 toRowMajor m =
112 if m.isRowMajor? then m else flipped m fi; 120 if isRowMajor? m then m else flipped m fi;
113 121
114 toColumnMajor m = 122 toColumnMajor m =
115 if not m.isRowMajor? then m else flipped m fi; 123 if not isRowMajor? m then m else flipped m fi;
116 124
117 // Matrices with different storage order but the same contents are 125 // Matrices with different storage order but the same contents are
118 // equal (but comparing them is slow) 126 // equal (but comparing them is slow)
119 equal m1 m2 = 127 equal m1 m2 =
120 if m1.size != m2.size then false 128 if size m1 != size m2 then false
121 elif m1.isRowMajor? != m2.isRowMajor? then equal (flipped m1) m2; 129 elif isRowMajor? m1 != isRowMajor? m2 then equal (flipped m1) m2;
122 else 130 else
123 compare d1 d2 = all id (map2 vec.equal d1 d2); 131 compare d1 d2 = all id (map2 vec.equal d1 d2);
124 case m1.data of 132 case m1 of
125 RowM d1: case m2.data of RowM d2: compare d1 d2; _: false; esac; 133 RowM d1: case m2 of RowM d2: compare d1 d2; _: false; esac;
126 ColM d1: case m2.data of ColM d2: compare d1 d2; _: false; esac; 134 ColM d1: case m2 of ColM d2: compare d1 d2; _: false; esac;
127 esac 135 esac
128 fi; 136 fi;
129 137
138 /*!!! not needed now it's immutable?
130 copyOf m = 139 copyOf m =
131 (copyOfData d = (array (map vec.copyOf d)); 140 (copyOfData d = (array (map vec.copyOf d));
132 make 141 case m of
133 (case m.data of 142 RowM d: RowM (copyOfData d);
134 RowM d: RowM (copyOfData d); 143 ColM d: ColM (copyOfData d);
135 ColM d: ColM (copyOfData d); 144 esac);
136 esac)); 145 */
137 146
138 newMatrix type data = //!!! NB does not copy data 147 newMatrix type data = //!!! NB does not copy data
139 (tagger = case type of RowMajor (): RowM; ColumnMajor (): ColM esac; 148 (tagger = case type of RowMajor (): RowM; ColumnMajor (): ColM esac;
140 if empty? data or block.empty? (head data) 149 if empty? data or vec.empty? (head data)
141 then zeroSizeMatrix () 150 then zeroSizeMatrix ()
142 else make (tagger (array (map block.data data))) 151 else tagger (array data)
143 fi); 152 fi);
144 153
145 newRowVector data = //!!! NB does not copy data 154 newRowVector data = //!!! NB does not copy data
146 make (RowM (array [block.data data])); 155 RowM (array [data]);
147 156
148 newColumnVector data = //!!! NB does not copy data 157 newColumnVector data = //!!! NB does not copy data
149 make (ColM (array [block.data data])); 158 ColM (array [data]);
150 159
151 scaled factor m = //!!! v inefficient 160 scaled factor m = //!!! v inefficient
152 generate do row col: factor * m.getAt row col done m.size; 161 generate do row col: factor * (getAt row col m) done (size m);
153 162
154 sum' m1 m2 = 163 sum' m1 m2 =
155 if m1.size != m2.size 164 if (size m1) != (size m2)
156 then failWith "Matrices are not the same size: \(m1.size), \(m2.size)"; 165 then failWith "Matrices are not the same size: \(size m1), \(size m2)";
157 else 166 else
158 generate do row col: m1.getAt row col + m2.getAt row col done m1.size; 167 generate do row col: getAt row col m1 + getAt row col m2 done (size m1);
159 fi; 168 fi;
160 169
161 product m1 m2 = 170 product m1 m2 =
162 if m1.size.columns != m2.size.rows 171 if (size m1).columns != (size m2).rows
163 then failWith "Matrix dimensions incompatible: \(m1.size), \(m2.size) (\(m1.size.columns != m2.size.rows)"; 172 then failWith "Matrix dimensions incompatible: \(size m1), \(size m2) (\((size m1).columns != (size m2).rows)";
164 else 173 else
165 generate do row col: 174 generate do row col:
166 bf.sum (bf.multiply (m1.getRow row) (m2.getColumn col)) 175 bf.sum (bf.multiply (getRow row m1) (getColumn col m2))
167 done { rows = m1.size.rows, columns = m2.size.columns } 176 done { rows = (size m1).rows, columns = (size m2).columns }
168 fi; 177 fi;
169 178
170 asRows m = 179 asRows m =
171 map m.getRow [0 .. m.size.rows - 1]; 180 map do i: getRow i m done [0 .. (height m) - 1];
172 181
173 asColumns m = 182 asColumns m =
174 map m.getColumn [0 .. m.size.columns - 1]; 183 map do i: getColumn i m done [0 .. (width m) - 1];
175 184
176 concatAgainstGrain tagger getter counter mm = 185 concatAgainstGrain tagger getter counter mm =
177 (n = counter (head mm).size; 186 (n = counter (size (head mm));
178 make (tagger (array 187 tagger (array
179 (map do i: 188 (map do i:
180 block.data (block.concat (map do m: getter m i done mm)) 189 vec.concat (map (getter i) mm)
181 done [0..n-1])))); 190 done [0..n-1])));
182 191
183 concatWithGrain tagger getter counter mm = 192 concatWithGrain tagger getter counter mm =
184 make (tagger (array 193 tagger (array
185 (concat 194 (concat
186 (map do m: 195 (map do m:
187 n = counter m.size; 196 n = counter (size m);
188 map do i: block.data (getter m i) done [0..n-1] 197 map do i: getter i m done [0..n-1]
189 done mm)))); 198 done mm)));
190 199
191 checkDimensionsFor direction first mm = 200 checkDimensionsFor direction first mm =
192 (counter = if direction == Horizontal () then (.rows) else (.columns) fi; 201 (counter = if direction == Horizontal () then (.rows) else (.columns) fi;
193 n = counter first.size; 202 n = counter (size first);
194 if not (all id (map do m: counter m.size == n done mm)) then 203 if not (all id (map do m: counter (size m) == n done mm)) then
195 failWith "Matrix dimensions incompatible for concat (found \(map do m: counter m.size done mm) not all of which are \(n))"; 204 failWith "Matrix dimensions incompatible for concat (found \(map do m: counter (size m) done mm) not all of which are \(n))";
196 fi); 205 fi);
197 206
198 concat direction mm = //!!! doc: storage order is taken from first matrix in sequence 207 concat direction mm = //!!! doc: storage order is taken from first matrix in sequence
199 //!!! would this be better as separate concatHorizontal/concatVertical functions? 208 //!!! would this be better as separate concatHorizontal/concatVertical functions?
200 case mm of 209 case mm of
201 first::rest: 210 first::rest:
202 checkDimensionsFor direction first mm; 211 checkDimensionsFor direction first mm;
203 row = first.isRowMajor?; 212 row = isRowMajor? first;
204 // horizontal, row-major: against grain with rows 213 // horizontal, row-major: against grain with rows
205 // horizontal, col-major: with grain with cols 214 // horizontal, col-major: with grain with cols
206 // vertical, row-major: with grain with rows 215 // vertical, row-major: with grain with rows
207 // vertical, col-major: against grain with cols 216 // vertical, col-major: against grain with cols
208 case direction of 217 case direction of
209 Horizontal (): 218 Horizontal ():
210 if row then concatAgainstGrain RowM (.getRow) (.rows) mm; 219 if row then concatAgainstGrain RowM getRow (.rows) mm;
211 else concatWithGrain ColM (.getColumn) (.columns) mm; 220 else concatWithGrain ColM getColumn (.columns) mm;
212 fi; 221 fi;
213 Vertical (): 222 Vertical ():
214 if row then concatWithGrain RowM (.getRow) (.rows) mm; 223 if row then concatWithGrain RowM getRow (.rows) mm;
215 else concatAgainstGrain ColM (.getColumn) (.columns) mm; 224 else concatAgainstGrain ColM getColumn (.columns) mm;
216 fi; 225 fi;
217 esac; 226 esac;
218 [single]: single; 227 [single]: single;
219 _: zeroSizeMatrix (); 228 _: zeroSizeMatrix ();
220 esac; 229 esac;
221 230
222 rowSlice start count m = //!!! doc: storage order same as input 231 rowSlice start count m = //!!! doc: storage order same as input
223 if m.isRowMajor? then 232 if isRowMajor? m then
224 make (RowM (array (map (block.data . m.getRow) [start .. start + count - 1]))) 233 RowM (array (map ((flip getRow) m) [start .. start + count - 1]))
225 else 234 else
226 make (ColM (array (map (block.data . (block.rangeOf start count)) (asColumns m)))) 235 ColM (array (map (vec.rangeOf start count) (asColumns m)))
227 fi; 236 fi;
228 237
229 columnSlice start count m = //!!! doc: storage order same as input 238 columnSlice start count m = //!!! doc: storage order same as input
230 if not m.isRowMajor? then 239 if not isRowMajor? m then
231 make (ColM (array (map (block.data . m.getColumn) [start .. start + count - 1]))) 240 ColM (array (map ((flip getColumn) m) [start .. start + count - 1]))
232 else 241 else
233 make (RowM (array (map (block.data . (block.rangeOf start count)) (asRows m)))) 242 RowM (array (map (vec.rangeOf start count) (asRows m)))
234 fi; 243 fi;
235 244
236 resizedTo newsize m = 245 resizedTo newsize m =
237 (if newsize == m.size then 246 (if newsize == (size m) then
238 m 247 m
239 elif m.size.rows == 0 or m.size.columns == 0 then 248 elif (height m) == 0 or (width m) == 0 then
240 zeroMatrixWithTypeOf m newsize; 249 zeroMatrixWithTypeOf m newsize;
241 else 250 else
242 growrows = newsize.rows - m.size.rows; 251 growrows = newsize.rows - (height m);
243 growcols = newsize.columns - m.size.columns; 252 growcols = newsize.columns - (width m);
244 rowm = m.isRowMajor?; 253 rowm = isRowMajor? m;
245 resizedTo newsize 254 resizedTo newsize
246 if rowm and growrows < 0 then 255 if rowm and growrows < 0 then
247 rowSlice 0 newsize.rows m 256 rowSlice 0 newsize.rows m
248 elif (not rowm) and growcols < 0 then 257 elif (not rowm) and growcols < 0 then
249 columnSlice 0 newsize.columns m 258 columnSlice 0 newsize.columns m
252 elif growcols < 0 then 261 elif growcols < 0 then
253 columnSlice 0 newsize.columns m 262 columnSlice 0 newsize.columns m
254 else 263 else
255 if growrows > 0 then 264 if growrows > 0 then
256 concat (Vertical ()) 265 concat (Vertical ())
257 [m, zeroMatrixWithTypeOf m (m.size with { rows = growrows })] 266 [m, zeroMatrixWithTypeOf m ((size m) with { rows = growrows })]
258 else 267 else
259 concat (Horizontal ()) 268 concat (Horizontal ())
260 [m, zeroMatrixWithTypeOf m (m.size with { columns = growcols })] 269 [m, zeroMatrixWithTypeOf m ((size m) with { columns = growcols })]
261 fi 270 fi
262 fi 271 fi
263 fi); 272 fi);
264 273
265 { 274 {
275 size,
276 width,
277 height,
278 getAt,
279 getColumn,
280 getRow,
281 // setAt,
282 isRowMajor?,
283 generate,
284 constMatrix,
285 randomMatrix,
286 zeroMatrix,
287 identityMatrix,
288 zeroSizeMatrix,
289 equal,
290 // copyOf,
291 transposed,
292 flipped,
293 toRowMajor,
294 toColumnMajor,
295 scaled,
296 resizedTo,
297 asRows,
298 asColumns,
299 sum = sum',
300 product,
301 concat,
302 rowSlice,
303 columnSlice,
304 newMatrix,
305 newRowVector,
306 newColumnVector,
307 }
308 as
309 {
310 //!!! check whether these are right to be .selector rather than just selector
311
312 size is matrix -> { .rows is number, .columns is number },
313 width is matrix -> number,
314 height is matrix -> number,
315 getAt is number -> number -> matrix -> number,
316 getColumn is number -> matrix -> vector,
317 getRow is number -> matrix -> vector,
318 // setAt is number -> number -> number -> matrix -> (), //!!! lose?
319 isRowMajor? is matrix -> boolean,
266 generate is (number -> number -> number) -> { .rows is number, .columns is number } -> matrix, 320 generate is (number -> number -> number) -> { .rows is number, .columns is number } -> matrix,
267 constMatrix is number -> { .rows is number, .columns is number } -> matrix, 321 constMatrix is number -> { .rows is number, .columns is number } -> matrix,
268 randomMatrix is { .rows is number, .columns is number } -> matrix, 322 randomMatrix is { .rows is number, .columns is number } -> matrix,
269 zeroMatrix is { .rows is number, .columns is number } -> matrix, 323 zeroMatrix is { .rows is number, .columns is number } -> matrix,
270 identityMatrix is { .rows is number, .columns is number } -> matrix, 324 identityMatrix is { .rows is number, .columns is number } -> matrix,
271 zeroSizeMatrix is () -> matrix, 325 zeroSizeMatrix is () -> matrix,
272 width is matrix -> number,
273 height is matrix -> number,
274 equal is matrix -> matrix -> boolean, 326 equal is matrix -> matrix -> boolean,
275 copyOf is matrix -> matrix, 327 // copyOf is matrix -> matrix,
276 transposed is matrix -> matrix, 328 transposed is matrix -> matrix,
277 flipped is matrix -> matrix, 329 flipped is matrix -> matrix,
278 toRowMajor is matrix -> matrix, 330 toRowMajor is matrix -> matrix,
279 toColumnMajor is matrix -> matrix, 331 toColumnMajor is matrix -> matrix,
280 scaled is number -> matrix -> matrix, 332 scaled is number -> matrix -> matrix,
281 resizedTo is { .rows is number, .columns is number } -> matrix -> matrix, 333 resizedTo is { .rows is number, .columns is number } -> matrix -> matrix,
282 asRows is matrix -> list<block>, 334 asRows is matrix -> list<vector>,
283 asColumns is matrix -> list<block>, 335 asColumns is matrix -> list<vector>,
284 sum is matrix -> matrix -> matrix = sum', 336 sum is matrix -> matrix -> matrix,
285 product is matrix -> matrix -> matrix, 337 product is matrix -> matrix -> matrix,
286 concat is (Horizontal () | Vertical ()) -> list<matrix> -> matrix, 338 concat is (Horizontal () | Vertical ()) -> list<matrix> -> matrix,
287 rowSlice is number -> number -> matrix -> matrix, 339 rowSlice is number -> number -> matrix -> matrix,
288 columnSlice is number -> number -> matrix -> matrix, 340 columnSlice is number -> number -> matrix -> matrix,
289 newMatrix is (ColumnMajor () | RowMajor ()) -> list<block> -> matrix, 341 newMatrix is (ColumnMajor () | RowMajor ()) -> list<vector> -> matrix,
290 newRowVector is block -> matrix, 342 newRowVector is vector -> matrix,
291 newColumnVector is block -> matrix, 343 newColumnVector is vector -> matrix,
292 } 344 }
293 345