Mercurial > hg > may
comparison yetilab/matrix/matrix.yeti @ 223:51af10e6cd0d
Merge from matrix_opaque_immutable branch
author | Chris Cannam |
---|---|
date | Sat, 11 May 2013 16:00:58 +0100 |
parents | 77c6a81c577f |
children | c00d8f7e2708 |
comparison
equal
deleted
inserted
replaced
207:cd2caf235e1f | 223:51af10e6cd0d |
---|---|
1 | 1 |
2 module yetilab.matrix.matrix; | 2 module yetilab.matrix.matrix; |
3 | 3 |
4 // A matrix is an array of fvectors (i.e. primitive double[]s). | 4 // A matrix is an array of vectors. |
5 | 5 |
6 // A matrix can be stored in either column-major (the default) or | 6 // A matrix can be stored in either column-major (the default) or |
7 // row-major format. Storage order is an efficiency concern only: | 7 // row-major format. Storage order is an efficiency concern only: |
8 // every API function operating on matrix objects will return the same | 8 // every API function operating on matrix objects will return the same |
9 // result regardless of storage order. (The transpose function just | 9 // result regardless of storage order. (The transpose function just |
10 // switches the row/column order without moving the elements.) | 10 // switches the row/column order without moving the elements.) |
11 | 11 |
12 vec = load yetilab.block.fvector; | 12 //!!! check that we are not unnecessarily copying in the transform functions |
13 block = load yetilab.block.block; | 13 |
14 bf = load yetilab.block.blockfuncs; | 14 vec = load yetilab.vector.vector; |
15 | 15 bf = load yetilab.vector.blockfuncs; |
16 load yetilab.block.blocktype; | 16 |
17 load yetilab.vector.vectortype; | |
17 load yetilab.matrix.matrixtype; | 18 load yetilab.matrix.matrixtype; |
18 | 19 |
19 make d = { | 20 size m = |
20 get data () = d, | 21 case m of |
21 get size () = | 22 RowM r: |
22 case d of | 23 major = length r; |
23 RowM r: | 24 { |
24 major = length r; | 25 rows = major, |
25 { | 26 columns = if major > 0 then vec.length r[0] else 0 fi, |
26 rows = major, | 27 }; |
27 columns = if major > 0 then vec.length r[0] else 0 fi, | 28 ColM c: |
28 }; | 29 major = length c; |
29 ColM c: | 30 { |
30 major = length c; | 31 rows = if major > 0 then vec.length c[0] else 0 fi, |
31 { | 32 columns = major, |
32 rows = if major > 0 then vec.length c[0] else 0 fi, | 33 }; |
33 columns = major, | 34 esac; |
34 }; | 35 |
35 esac, | 36 width m = (size m).columns; |
36 getColumn j = | 37 height m = (size m).rows; |
37 case d of | 38 |
38 RowM rows: block.fromList (map do i: getAt i j done [0..length rows-1]); | 39 getAt row col m = |
39 ColM cols: block.block cols[j]; | 40 case m of |
40 esac, | 41 RowM rows: r = rows[row]; vec.at col r; |
41 getRow i = | 42 ColM cols: c = cols[col]; vec.at row c; |
42 case d of | 43 esac; |
43 RowM rows: block.block rows[i]; | 44 |
44 ColM cols: block.fromList (map do j: getAt i j done [0..length cols-1]); | 45 getColumn j m = |
45 esac, | 46 case m of |
46 getAt row col = | 47 RowM rows: vec.fromList (map do i: getAt i j m done [0..length rows-1]); |
47 case d of | 48 ColM cols: cols[j]; |
48 RowM rows: r = rows[row]; (r is ~double[])[col]; | 49 esac; |
49 ColM cols: c = cols[col]; (c is ~double[])[row]; | 50 |
50 esac, | 51 getRow i m = |
51 setAt row col n = //!!! dangerous, could modify copies -- should it be allowed? | 52 case m of |
52 case d of | 53 RowM rows: rows[i]; |
53 RowM rows: r = rows[row]; (r is ~double[])[col] := n; | 54 ColM cols: vec.fromList (map do j: getAt i j m done [0..length cols-1]); |
54 ColM cols: c = cols[col]; (c is ~double[])[row] := n; | 55 esac; |
55 esac, | 56 |
56 get isRowMajor? () = | 57 /* |
57 case d of | 58 setAt row col n m = //!!! dangerous, could modify copies -- should it be allowed? |
58 RowM _: true; | 59 case m of |
59 ColM _: false; | 60 RowM rows: r = rows[row]; (vec.data r)[col] := n; |
60 esac, | 61 ColM cols: c = cols[col]; (vec.data c)[row] := n; |
61 }; | 62 esac; |
63 */ | |
64 | |
65 isRowMajor? m = | |
66 case m of | |
67 RowM _: true; | |
68 ColM _: false; | |
69 esac; | |
62 | 70 |
63 newColMajorStorage { rows, columns } = | 71 newColMajorStorage { rows, columns } = |
64 if rows < 1 then array [] | 72 if rows < 1 then array [] |
65 else array (map \(vec.zeros rows) [1..columns]) | 73 else array (map \(vec.zeros rows) [1..columns]) |
66 fi; | 74 fi; |
67 | 75 |
68 zeroMatrix { rows, columns } = | 76 zeroMatrix { rows, columns } = |
69 make (ColM (newColMajorStorage { rows, columns })); | 77 ColM (newColMajorStorage { rows, columns }); |
70 | 78 |
71 zeroMatrixWithTypeOf m { rows, columns } = | 79 zeroMatrixWithTypeOf m { rows, columns } = |
72 if m.isRowMajor? then | 80 if isRowMajor? m then |
73 make (RowM (newColMajorStorage { rows = columns, columns = rows })); | 81 RowM (newColMajorStorage { rows = columns, columns = rows }); |
74 else | 82 else |
75 make (ColM (newColMajorStorage { rows, columns })); | 83 ColM (newColMajorStorage { rows, columns }); |
76 fi; | 84 fi; |
85 | |
86 zeroSizeMatrix () = zeroMatrix { rows = 0, columns = 0 }; | |
77 | 87 |
78 generate f { rows, columns } = | 88 generate f { rows, columns } = |
79 (m = newColMajorStorage { rows, columns }; | 89 if rows < 1 or columns < 1 then zeroSizeMatrix () |
80 for [0..columns-1] do col: | 90 else |
81 for [0..rows-1] do row: | 91 m = array (map \(new double[rows]) [1..columns]); |
82 m[col][row] := f row col; | 92 for [0..columns-1] do col: |
93 for [0..rows-1] do row: | |
94 m[col][row] := f row col; | |
95 done; | |
83 done; | 96 done; |
84 done; | 97 ColM (array (map vec.vector m)) |
85 make (ColM m)); | 98 fi; |
86 | 99 |
87 constMatrix n = generate do row col: n done; | 100 constMatrix n = generate do row col: n done; |
88 randomMatrix = generate do row col: Math#random() done; | 101 randomMatrix = generate do row col: Math#random() done; |
89 identityMatrix = constMatrix 1; | 102 identityMatrix = constMatrix 1; |
90 zeroSizeMatrix () = zeroMatrix { rows = 0, columns = 0 }; | |
91 | |
92 width m = m.size.columns; | |
93 height m = m.size.rows; | |
94 | 103 |
95 transposed m = | 104 transposed m = |
96 make | 105 case m of |
97 (case m.data of | 106 RowM d: ColM d; |
98 RowM d: ColM d; | 107 ColM d: RowM d; |
99 ColM d: RowM d; | 108 esac; |
100 esac); | |
101 | 109 |
102 flipped m = | 110 flipped m = |
103 if m.isRowMajor? then | 111 if isRowMajor? m then |
104 generate do row col: m.getAt row col done m.size; | 112 generate do row col: getAt row col m done (size m); |
105 else | 113 else |
106 transposed | 114 transposed |
107 (generate do row col: m.getAt col row done | 115 (generate do row col: getAt col row m done |
108 { rows = m.size.columns, columns = m.size.rows }); | 116 { rows = (width m), columns = (height m) }); |
109 fi; | 117 fi; |
110 | 118 |
111 toRowMajor m = | 119 toRowMajor m = |
112 if m.isRowMajor? then m else flipped m fi; | 120 if isRowMajor? m then m else flipped m fi; |
113 | 121 |
114 toColumnMajor m = | 122 toColumnMajor m = |
115 if not m.isRowMajor? then m else flipped m fi; | 123 if not isRowMajor? m then m else flipped m fi; |
116 | 124 |
117 // Matrices with different storage order but the same contents are | 125 // Matrices with different storage order but the same contents are |
118 // equal (but comparing them is slow) | 126 // equal (but comparing them is slow) |
119 equal m1 m2 = | 127 equal m1 m2 = |
120 if m1.size != m2.size then false | 128 if size m1 != size m2 then false |
121 elif m1.isRowMajor? != m2.isRowMajor? then equal (flipped m1) m2; | 129 elif isRowMajor? m1 != isRowMajor? m2 then equal (flipped m1) m2; |
122 else | 130 else |
123 compare d1 d2 = all id (map2 vec.equal d1 d2); | 131 compare d1 d2 = all id (map2 vec.equal d1 d2); |
124 case m1.data of | 132 case m1 of |
125 RowM d1: case m2.data of RowM d2: compare d1 d2; _: false; esac; | 133 RowM d1: case m2 of RowM d2: compare d1 d2; _: false; esac; |
126 ColM d1: case m2.data of ColM d2: compare d1 d2; _: false; esac; | 134 ColM d1: case m2 of ColM d2: compare d1 d2; _: false; esac; |
127 esac | 135 esac |
128 fi; | 136 fi; |
129 | 137 |
138 /*!!! not needed now it's immutable? | |
130 copyOf m = | 139 copyOf m = |
131 (copyOfData d = (array (map vec.copyOf d)); | 140 (copyOfData d = (array (map vec.copyOf d)); |
132 make | 141 case m of |
133 (case m.data of | 142 RowM d: RowM (copyOfData d); |
134 RowM d: RowM (copyOfData d); | 143 ColM d: ColM (copyOfData d); |
135 ColM d: ColM (copyOfData d); | 144 esac); |
136 esac)); | 145 */ |
137 | 146 |
138 newMatrix type data = //!!! NB does not copy data | 147 newMatrix type data = //!!! NB does not copy data |
139 (tagger = case type of RowMajor (): RowM; ColumnMajor (): ColM esac; | 148 (tagger = case type of RowMajor (): RowM; ColumnMajor (): ColM esac; |
140 if empty? data or block.empty? (head data) | 149 if empty? data or vec.empty? (head data) |
141 then zeroSizeMatrix () | 150 then zeroSizeMatrix () |
142 else make (tagger (array (map block.data data))) | 151 else tagger (array data) |
143 fi); | 152 fi); |
144 | 153 |
145 newRowVector data = //!!! NB does not copy data | 154 newRowVector data = //!!! NB does not copy data |
146 make (RowM (array [block.data data])); | 155 RowM (array [data]); |
147 | 156 |
148 newColumnVector data = //!!! NB does not copy data | 157 newColumnVector data = //!!! NB does not copy data |
149 make (ColM (array [block.data data])); | 158 ColM (array [data]); |
150 | 159 |
151 scaled factor m = //!!! v inefficient | 160 scaled factor m = //!!! v inefficient |
152 generate do row col: factor * m.getAt row col done m.size; | 161 generate do row col: factor * (getAt row col m) done (size m); |
153 | 162 |
154 sum' m1 m2 = | 163 sum' m1 m2 = |
155 if m1.size != m2.size | 164 if (size m1) != (size m2) |
156 then failWith "Matrices are not the same size: \(m1.size), \(m2.size)"; | 165 then failWith "Matrices are not the same size: \(size m1), \(size m2)"; |
157 else | 166 else |
158 generate do row col: m1.getAt row col + m2.getAt row col done m1.size; | 167 generate do row col: getAt row col m1 + getAt row col m2 done (size m1); |
159 fi; | 168 fi; |
160 | 169 |
161 product m1 m2 = | 170 product m1 m2 = |
162 if m1.size.columns != m2.size.rows | 171 if (size m1).columns != (size m2).rows |
163 then failWith "Matrix dimensions incompatible: \(m1.size), \(m2.size) (\(m1.size.columns != m2.size.rows)"; | 172 then failWith "Matrix dimensions incompatible: \(size m1), \(size m2) (\((size m1).columns != (size m2).rows)"; |
164 else | 173 else |
165 generate do row col: | 174 generate do row col: |
166 bf.sum (bf.multiply (m1.getRow row) (m2.getColumn col)) | 175 bf.sum (bf.multiply (getRow row m1) (getColumn col m2)) |
167 done { rows = m1.size.rows, columns = m2.size.columns } | 176 done { rows = (size m1).rows, columns = (size m2).columns } |
168 fi; | 177 fi; |
169 | 178 |
170 asRows m = | 179 asRows m = |
171 map m.getRow [0 .. m.size.rows - 1]; | 180 map do i: getRow i m done [0 .. (height m) - 1]; |
172 | 181 |
173 asColumns m = | 182 asColumns m = |
174 map m.getColumn [0 .. m.size.columns - 1]; | 183 map do i: getColumn i m done [0 .. (width m) - 1]; |
175 | 184 |
176 concatAgainstGrain tagger getter counter mm = | 185 concatAgainstGrain tagger getter counter mm = |
177 (n = counter (head mm).size; | 186 (n = counter (size (head mm)); |
178 make (tagger (array | 187 tagger (array |
179 (map do i: | 188 (map do i: |
180 block.data (block.concat (map do m: getter m i done mm)) | 189 vec.concat (map (getter i) mm) |
181 done [0..n-1])))); | 190 done [0..n-1]))); |
182 | 191 |
183 concatWithGrain tagger getter counter mm = | 192 concatWithGrain tagger getter counter mm = |
184 make (tagger (array | 193 tagger (array |
185 (concat | 194 (concat |
186 (map do m: | 195 (map do m: |
187 n = counter m.size; | 196 n = counter (size m); |
188 map do i: block.data (getter m i) done [0..n-1] | 197 map do i: getter i m done [0..n-1] |
189 done mm)))); | 198 done mm))); |
190 | 199 |
191 checkDimensionsFor direction first mm = | 200 checkDimensionsFor direction first mm = |
192 (counter = if direction == Horizontal () then (.rows) else (.columns) fi; | 201 (counter = if direction == Horizontal () then (.rows) else (.columns) fi; |
193 n = counter first.size; | 202 n = counter (size first); |
194 if not (all id (map do m: counter m.size == n done mm)) then | 203 if not (all id (map do m: counter (size m) == n done mm)) then |
195 failWith "Matrix dimensions incompatible for concat (found \(map do m: counter m.size done mm) not all of which are \(n))"; | 204 failWith "Matrix dimensions incompatible for concat (found \(map do m: counter (size m) done mm) not all of which are \(n))"; |
196 fi); | 205 fi); |
197 | 206 |
198 concat direction mm = //!!! doc: storage order is taken from first matrix in sequence | 207 concat direction mm = //!!! doc: storage order is taken from first matrix in sequence |
199 //!!! would this be better as separate concatHorizontal/concatVertical functions? | 208 //!!! would this be better as separate concatHorizontal/concatVertical functions? |
200 case mm of | 209 case mm of |
201 first::rest: | 210 first::rest: |
202 checkDimensionsFor direction first mm; | 211 checkDimensionsFor direction first mm; |
203 row = first.isRowMajor?; | 212 row = isRowMajor? first; |
204 // horizontal, row-major: against grain with rows | 213 // horizontal, row-major: against grain with rows |
205 // horizontal, col-major: with grain with cols | 214 // horizontal, col-major: with grain with cols |
206 // vertical, row-major: with grain with rows | 215 // vertical, row-major: with grain with rows |
207 // vertical, col-major: against grain with cols | 216 // vertical, col-major: against grain with cols |
208 case direction of | 217 case direction of |
209 Horizontal (): | 218 Horizontal (): |
210 if row then concatAgainstGrain RowM (.getRow) (.rows) mm; | 219 if row then concatAgainstGrain RowM getRow (.rows) mm; |
211 else concatWithGrain ColM (.getColumn) (.columns) mm; | 220 else concatWithGrain ColM getColumn (.columns) mm; |
212 fi; | 221 fi; |
213 Vertical (): | 222 Vertical (): |
214 if row then concatWithGrain RowM (.getRow) (.rows) mm; | 223 if row then concatWithGrain RowM getRow (.rows) mm; |
215 else concatAgainstGrain ColM (.getColumn) (.columns) mm; | 224 else concatAgainstGrain ColM getColumn (.columns) mm; |
216 fi; | 225 fi; |
217 esac; | 226 esac; |
218 [single]: single; | 227 [single]: single; |
219 _: zeroSizeMatrix (); | 228 _: zeroSizeMatrix (); |
220 esac; | 229 esac; |
221 | 230 |
222 rowSlice start count m = //!!! doc: storage order same as input | 231 rowSlice start count m = //!!! doc: storage order same as input |
223 if m.isRowMajor? then | 232 if isRowMajor? m then |
224 make (RowM (array (map (block.data . m.getRow) [start .. start + count - 1]))) | 233 RowM (array (map ((flip getRow) m) [start .. start + count - 1])) |
225 else | 234 else |
226 make (ColM (array (map (block.data . (block.rangeOf start count)) (asColumns m)))) | 235 ColM (array (map (vec.rangeOf start count) (asColumns m))) |
227 fi; | 236 fi; |
228 | 237 |
229 columnSlice start count m = //!!! doc: storage order same as input | 238 columnSlice start count m = //!!! doc: storage order same as input |
230 if not m.isRowMajor? then | 239 if not isRowMajor? m then |
231 make (ColM (array (map (block.data . m.getColumn) [start .. start + count - 1]))) | 240 ColM (array (map ((flip getColumn) m) [start .. start + count - 1])) |
232 else | 241 else |
233 make (RowM (array (map (block.data . (block.rangeOf start count)) (asRows m)))) | 242 RowM (array (map (vec.rangeOf start count) (asRows m))) |
234 fi; | 243 fi; |
235 | 244 |
236 resizedTo newsize m = | 245 resizedTo newsize m = |
237 (if newsize == m.size then | 246 (if newsize == (size m) then |
238 m | 247 m |
239 elif m.size.rows == 0 or m.size.columns == 0 then | 248 elif (height m) == 0 or (width m) == 0 then |
240 zeroMatrixWithTypeOf m newsize; | 249 zeroMatrixWithTypeOf m newsize; |
241 else | 250 else |
242 growrows = newsize.rows - m.size.rows; | 251 growrows = newsize.rows - (height m); |
243 growcols = newsize.columns - m.size.columns; | 252 growcols = newsize.columns - (width m); |
244 rowm = m.isRowMajor?; | 253 rowm = isRowMajor? m; |
245 resizedTo newsize | 254 resizedTo newsize |
246 if rowm and growrows < 0 then | 255 if rowm and growrows < 0 then |
247 rowSlice 0 newsize.rows m | 256 rowSlice 0 newsize.rows m |
248 elif (not rowm) and growcols < 0 then | 257 elif (not rowm) and growcols < 0 then |
249 columnSlice 0 newsize.columns m | 258 columnSlice 0 newsize.columns m |
252 elif growcols < 0 then | 261 elif growcols < 0 then |
253 columnSlice 0 newsize.columns m | 262 columnSlice 0 newsize.columns m |
254 else | 263 else |
255 if growrows > 0 then | 264 if growrows > 0 then |
256 concat (Vertical ()) | 265 concat (Vertical ()) |
257 [m, zeroMatrixWithTypeOf m (m.size with { rows = growrows })] | 266 [m, zeroMatrixWithTypeOf m ((size m) with { rows = growrows })] |
258 else | 267 else |
259 concat (Horizontal ()) | 268 concat (Horizontal ()) |
260 [m, zeroMatrixWithTypeOf m (m.size with { columns = growcols })] | 269 [m, zeroMatrixWithTypeOf m ((size m) with { columns = growcols })] |
261 fi | 270 fi |
262 fi | 271 fi |
263 fi); | 272 fi); |
264 | 273 |
265 { | 274 { |
275 size, | |
276 width, | |
277 height, | |
278 getAt, | |
279 getColumn, | |
280 getRow, | |
281 // setAt, | |
282 isRowMajor?, | |
283 generate, | |
284 constMatrix, | |
285 randomMatrix, | |
286 zeroMatrix, | |
287 identityMatrix, | |
288 zeroSizeMatrix, | |
289 equal, | |
290 // copyOf, | |
291 transposed, | |
292 flipped, | |
293 toRowMajor, | |
294 toColumnMajor, | |
295 scaled, | |
296 resizedTo, | |
297 asRows, | |
298 asColumns, | |
299 sum = sum', | |
300 product, | |
301 concat, | |
302 rowSlice, | |
303 columnSlice, | |
304 newMatrix, | |
305 newRowVector, | |
306 newColumnVector, | |
307 } | |
308 as | |
309 { | |
310 //!!! check whether these are right to be .selector rather than just selector | |
311 | |
312 size is matrix -> { .rows is number, .columns is number }, | |
313 width is matrix -> number, | |
314 height is matrix -> number, | |
315 getAt is number -> number -> matrix -> number, | |
316 getColumn is number -> matrix -> vector, | |
317 getRow is number -> matrix -> vector, | |
318 // setAt is number -> number -> number -> matrix -> (), //!!! lose? | |
319 isRowMajor? is matrix -> boolean, | |
266 generate is (number -> number -> number) -> { .rows is number, .columns is number } -> matrix, | 320 generate is (number -> number -> number) -> { .rows is number, .columns is number } -> matrix, |
267 constMatrix is number -> { .rows is number, .columns is number } -> matrix, | 321 constMatrix is number -> { .rows is number, .columns is number } -> matrix, |
268 randomMatrix is { .rows is number, .columns is number } -> matrix, | 322 randomMatrix is { .rows is number, .columns is number } -> matrix, |
269 zeroMatrix is { .rows is number, .columns is number } -> matrix, | 323 zeroMatrix is { .rows is number, .columns is number } -> matrix, |
270 identityMatrix is { .rows is number, .columns is number } -> matrix, | 324 identityMatrix is { .rows is number, .columns is number } -> matrix, |
271 zeroSizeMatrix is () -> matrix, | 325 zeroSizeMatrix is () -> matrix, |
272 width is matrix -> number, | |
273 height is matrix -> number, | |
274 equal is matrix -> matrix -> boolean, | 326 equal is matrix -> matrix -> boolean, |
275 copyOf is matrix -> matrix, | 327 // copyOf is matrix -> matrix, |
276 transposed is matrix -> matrix, | 328 transposed is matrix -> matrix, |
277 flipped is matrix -> matrix, | 329 flipped is matrix -> matrix, |
278 toRowMajor is matrix -> matrix, | 330 toRowMajor is matrix -> matrix, |
279 toColumnMajor is matrix -> matrix, | 331 toColumnMajor is matrix -> matrix, |
280 scaled is number -> matrix -> matrix, | 332 scaled is number -> matrix -> matrix, |
281 resizedTo is { .rows is number, .columns is number } -> matrix -> matrix, | 333 resizedTo is { .rows is number, .columns is number } -> matrix -> matrix, |
282 asRows is matrix -> list<block>, | 334 asRows is matrix -> list<vector>, |
283 asColumns is matrix -> list<block>, | 335 asColumns is matrix -> list<vector>, |
284 sum is matrix -> matrix -> matrix = sum', | 336 sum is matrix -> matrix -> matrix, |
285 product is matrix -> matrix -> matrix, | 337 product is matrix -> matrix -> matrix, |
286 concat is (Horizontal () | Vertical ()) -> list<matrix> -> matrix, | 338 concat is (Horizontal () | Vertical ()) -> list<matrix> -> matrix, |
287 rowSlice is number -> number -> matrix -> matrix, | 339 rowSlice is number -> number -> matrix -> matrix, |
288 columnSlice is number -> number -> matrix -> matrix, | 340 columnSlice is number -> number -> matrix -> matrix, |
289 newMatrix is (ColumnMajor () | RowMajor ()) -> list<block> -> matrix, | 341 newMatrix is (ColumnMajor () | RowMajor ()) -> list<vector> -> matrix, |
290 newRowVector is block -> matrix, | 342 newRowVector is vector -> matrix, |
291 newColumnVector is block -> matrix, | 343 newColumnVector is vector -> matrix, |
292 } | 344 } |
293 | 345 |