Chris@5
|
1
|
Chris@273
|
2 module yetilab.matrix;
|
Chris@94
|
3
|
Chris@214
|
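4 // A matrix is stored either densely, as an array of row or column vectors, or sparsely, in compressed sparse row (CSR) or column (CSC) form.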
|
Chris@94
|
5
|
Chris@101
|
6 // A matrix can be stored in either column-major (the default) or
|
Chris@101
|
7 // row-major format. Storage order is an efficiency concern only:
|
Chris@101
|
8 // every API function operating on matrix objects will return the same
|
Chris@101
|
9 // result regardless of storage order. (The transpose function just
|
Chris@101
|
10 // switches the row/column order without moving the elements.)
|
Chris@18
|
11
|
Chris@275
|
12 // The matrix representation does not take into account different
|
Chris@275
|
13 // forms of zero-width or zero-height matrix. That is, all matrices of
|
Chris@275
|
14 // zero width or height are equal to each other and to the zero-size
|
Chris@275
|
15 // matrix -- they can't be distinguished.
|
Chris@275
|
16
|
Chris@273
|
17 vec = load yetilab.vector;
|
Chris@222
|
18 bf = load yetilab.vector.blockfuncs;
|
Chris@9
|
19
|
Chris@273
|
20 load yetilab.vector.type;
|
Chris@273
|
21 load yetilab.matrix.type;
|
Chris@195
|
22
|
// Return the dimensions of m as { rows, columns }.
//
// Dense storage is an array of vectors: the array length is the major
// dimension and the length of the first vector the minor one (0 when
// there are no vectors at all). Sparse storage holds one more pointer
// than it has major slices and records the minor dimension as extent.
size m =
   (minorOf slices =
        if length slices == 0 then 0 else vec.length slices[0] fi;
    case m of
    DenseRows r: { rows = length r, columns = minorOf r };
    DenseCols c: { rows = minorOf c, columns = length c };
    SparseCSR d: { rows = length d.pointers - 1, columns = d.extent };
    SparseCSC d: { rows = d.extent, columns = length d.pointers - 1 };
    esac);
Chris@208
|
48
|
// Number of columns in m.
width = (.columns) . size;

// Number of rows in m.
height = (.rows) . size;
Chris@208
|
51
|
// Count of non-zero cells in m. Dense storage is scanned cell by cell.
// Sparse storage reports how many values it stores; makeSparse never
// stores zeros, so this is the same count.
nonZeroValues m =
   (countNonZero slices =
        sum (map do v:
                 length (filter do x: x != 0 done (vec.list v))
             done slices);
    case m of
    DenseRows d: countNonZero d;
    DenseCols d: countNonZero d;
    SparseCSR d: vec.length d.values;
    SparseCSC d: vec.length d.values;
    esac);
Chris@242
|
64
|
// Fraction of the cells of m that are non-zero, from 0 to 1.
// A zero-size matrix has no cells at all. Its density is reported as 0
// rather than dividing by a zero cell count.
density m =
   ({ rows, columns } = size m;
    cells = rows * columns;
    if cells == 0 then 0
    else (nonZeroValues m) / cells
    fi);
Chris@249
|
69
|
// Stored values and minor indices of major slice n of sparse data d
// (row n of CSR data, column n of CSC data), as { values, indices }.
sparseSlice n d =
   (lo = d.pointers[n];
    hi = d.pointers[n+1];
    {
        indices = slice d.indices lo hi,
        values = vec.slice d.values lo hi,
    });

// Indices of the major slices of sparse data d that hold at least one
// stored value, in ascending order.
nonEmptySlices d =
    array (filter do k: d.pointers[k] != d.pointers[k+1] done
                  [0 .. length d.pointers - 2]);

// Value stored at minor index m of major slice n of sparse data d, or 0
// if nothing is stored there. Every stored index in the slice is
// examined, so if m occurred more than once the last occurrence wins.
fromSlice n m d =
   (s = sparseSlice n d;
    fold do found k:
        if s.indices[k] == m then vec.at s.values k else found fi
    done 0 [0 .. length s.indices - 1]);

// Major slice n of sparse data d expanded into a dense vector of length
// d.extent, zero wherever no value is stored.
filledSlice n d =
   (s = sparseSlice n d;
    cells = new double[d.extent];
    for [0 .. length s.indices - 1] do k:
        cells[s.indices[k]] := vec.at s.values k
    done;
    vec.vector cells);
Chris@236
|
104
|
// Value of the cell at (row, col) of m, whatever the storage format.
// Sparse matrices give 0 for cells with no stored value.
at' m row col =
    case m of
    DenseRows rr: vec.at rr[row] col;
    DenseCols cc: vec.at cc[col] row;
    SparseCSR d: fromSlice row col d;
    SparseCSC d: fromSlice col row d;
    esac;

// Column j of m as a vector. Dense column-major storage returns the
// stored vector itself. CSC storage expands the stored slice. Any other
// format assembles the column cell by cell.
getColumn j m =
    case m of
    DenseCols cc: cc[j];
    SparseCSC d: filledSlice j d;
    _: vec.fromList (map do r: at' m r j done [0 .. height m - 1]);
    esac;

// Row i of m as a vector. Dense row-major storage returns the stored
// vector itself. CSR storage expands the stored slice. Any other format
// assembles the row cell by cell.
getRow i m =
    case m of
    DenseRows rr: rr[i];
    SparseCSR d: filledSlice i d;
    _: vec.fromList (map do c: at' m i c done [0 .. width m - 1]);
    esac;

// All rows of m, top to bottom.
asRows m = map ((flip getRow) m) [0 .. height m - 1];

// All columns of m, left to right.
asColumns m = map ((flip getColumn) m) [0 .. width m - 1];
Chris@257
|
132
|
// True if m is stored row by row (DenseRows or SparseCSR).
isRowMajor? m =
    case m of
    DenseRows _: true;
    SparseCSR _: true;
    DenseCols _: false;
    SparseCSC _: false;
    esac;

// True if m uses one of the compressed sparse formats (CSR or CSC).
isSparse? m =
    case m of
    SparseCSR _: true;
    SparseCSC _: true;
    DenseRows _: false;
    DenseCols _: false;
    esac;

// Storage-order tag of m: RowMajor () or ColumnMajor ().
typeOf m =
    if not (isRowMajor? m) then ColumnMajor () else RowMajor () fi;

// The storage-order tag opposite to that of m.
flippedTypeOf m =
    if not (isRowMajor? m) then RowMajor () else ColumnMajor () fi;
Chris@261
|
158
|
// Zero-filled dense storage: an array of `columns` vectors, each of
// length `rows`. If rows < 1 the array is empty.
newColumnMajorStorage { rows, columns } =
    if not (rows < 1) then
        array (map \(vec.zeros rows) [1 .. columns])
    else
        array []
    fi;

// Dense column-major matrix of the given size with every cell 0.
zeroMatrix { rows, columns } =
    DenseCols (newColumnMajorStorage { rows, columns });

// Zero-filled dense matrix of the given size with the same storage
// order as m.
zeroMatrixWithTypeOf m { rows, columns } =
    if not (isRowMajor? m) then
        DenseCols (newColumnMajorStorage { rows, columns })
    else
        // Row-major storage is an array of `rows` vectors of length
        // `columns`, i.e. column-major storage of the transposed shape.
        DenseRows (newColumnMajorStorage { rows = columns, columns = rows })
    fi;

// The canonical 0x0 matrix (dense, column-major).
zeroSizeMatrix () = zeroMatrix { columns = 0, rows = 0 };

// True if m has no rows or no columns.
isZeroSize? m =
   (sz = size m;
    sz.rows == 0 or sz.columns == 0);
Chris@275
|
177
|
// Dense column-major matrix of the given size whose cell (row, col) is
// f row col. f is called column by column, from top to bottom within
// each column. A size with fewer than one row or column gives the 0x0
// matrix, and f is never called.
generate f { rows, columns } =
    if not (rows < 1 or columns < 1) then
        DenseCols (array (map do col:
            cells = new double[rows];
            for [0 .. rows - 1] do row:
                cells[row] := f row col
            done;
            vec.vector cells
        done [0 .. columns - 1]))
    else
        zeroSizeMatrix ()
    fi;
Chris@5
|
189
|
// Exchange the i and j fields of every { i, j, v } entry in a list,
// converting (major, minor) entries to (row, column) ones or back.
swapij =
    map do e: { i = e.j, j = e.i, v = e.v } done;

//!!! should use { row = , column = , value = } instead of i, j, v?
// List the stored cells of sparse m as { i = row, j = column, v = value }
// entries, one major slice after another. Non-sparse m gives [].
enumerateSparse m =
   (majorEntries d =
        concat (map do maj:
            s = sparseSlice maj d;
            map2 do k x: { i = maj, j = k, v = x } done
                s.indices (vec.list s.values)
        done [0 .. length d.pointers - 2]);
    case m of
    SparseCSR d: majorEntries d;
    SparseCSC d: swapij (majorEntries d);
    _: [];
    esac);

// List every cell of dense m (zeros included) as { i = row, j = column,
// v = value } entries, one stored vector after another. Sparse m
// gives [].
enumerateDense m =
   (majorEntries slices =
        concat (map do maj:
            vv = slices[maj];
            map2 do k x: { i = maj, j = k, v = x } done
                [0 .. vec.length vv - 1] (vec.list vv)
        done [0 .. length slices - 1]);
    case m of
    DenseRows r: majorEntries r;
    DenseCols c: swapij (majorEntries c);
    _: [];
    esac);

// List the cells of m as { i, j, v } entries. Dense storage lists every
// cell; sparse storage lists only its stored cells.
enumerate m =
    if not (isSparse? m) then enumerateDense m else enumerateSparse m fi;
Chris@255
|
227
|
// Make a sparse matrix from entries whose i, j values are known to be
// within range. newSparseMatrix is the checked public entry point.
//
// type selects the storage: RowMajor () gives CSR, ColumnMajor () gives
// CSC. size is { rows, columns }; data is a list of { i, j, v } cells.
// Cells whose value is 0 are dropped. Entries are not de-duplicated, so
// a repeated (i, j) cell would be stored more than once.
makeSparse type size data =
   (isRow = case type of RowMajor (): true; ColumnMajor (): false esac;
    // Re-key each non-zero cell as (major, minor) for the target storage
    // order, then sort by major index and, within a major slice, by
    // minor index.
    ordered =
        sortBy do a b:
            if a.maj == b.maj then a.min < b.min else a.maj < b.maj fi
        done
           (map
                if isRow then
                    do { i, j, v }: { maj = i, min = j, v } done;
                else
                    do { i, j, v }: { maj = j, min = i, v } done;
                fi
               (filter do d: d.v != 0 done data));
    tagger = if isRow then SparseCSR else SparseCSC fi;
    majorSize = if isRow then size.rows else size.columns fi;
    minorSize = if isRow then size.columns else size.rows fi;
    // pointers[k] is the offset in values/indices at which major slice k
    // starts. There is one entry per major slice plus a final end marker,
    // so majorSize + 1 in total. This is the column count + 1 for CSC
    // storage, not rows + 1.
    pointers = array [0];
    setArrayCapacity pointers (majorSize + 1);
    // Walk the sorted major indices. n is the next major slice whose
    // start pointer has not been pushed yet; i is the number of entries
    // consumed so far. An empty slice gets the same start offset as the
    // slice after it.
    fillPointers n i majors =
        if n < majorSize then
            case majors of
            d::rest:
                (for [n..d-1] \(push pointers i);
                 fillPointers d (i+1) rest);
            _:
                for [n..majorSize-1] \(push pointers i);
            esac;
        fi;
    fillPointers 0 0 (map (.maj) ordered);
    tagger {
        values = vec.fromList (map (.v) ordered),
        indices = array (map (.min) ordered),
        pointers,
        extent = minorSize,
    });
Chris@238
|
265
|
// Make a sparse matrix of the given storage type and size from a list of
// { i, j, v } cells. Any cell whose i or j is not an integer index inside
// the matrix is silently discarded. This is the public API for
// makeSparse, and resizedTo also uses it to drop cells that fall outside
// the new size.
newSparseMatrix type size data =
   (validIndex x limit = x == int x and x >= 0 and x < limit;
    makeSparse type size
       (filter do e:
            validIndex e.i size.rows and validIndex e.j size.columns
        done data));

// Sparse version of m in the same storage order. Sparse input is
// returned unchanged.
toSparse m =
    if not (isSparse? m) then
        makeSparse (typeOf m) (size m) (enumerateDense m)
    else
        m
    fi;

// Dense version of m in the same storage order. Dense input is returned
// unchanged.
toDense m =
    if not (isSparse? m) then
        m
    elif isRowMajor? m then
        DenseRows (array (asRows m))
    else
        DenseCols (array (asColumns m))
    fi;
Chris@235
|
291
|
// Matrix of the given size with every cell set to n.
constMatrix n = generate do row col: n done;

// Matrix of the given size filled with values from Java's Math.random().
randomMatrix = generate do row col: Math#random() done;

// Matrix of the given size with 1 on the leading diagonal (row == col)
// and 0 everywhere else. Non-square sizes are accepted; the diagonal
// then stops at the smaller dimension. (Previously this was
// constMatrix 1, which fills every cell with 1 and is not an identity
// matrix.)
identityMatrix = generate do row col: if row == col then 1 else 0 fi done;
Chris@5
|
295
|
// Transpose of m. No elements are moved: the same storage is relabelled
// with rows and columns exchanged, so row-major storage becomes
// column-major and vice versa.
transposed m =
    case m of
    DenseCols d: DenseRows d;
    DenseRows d: DenseCols d;
    SparseCSC d: SparseCSR d;
    SparseCSR d: SparseCSC d;
    esac;

// Copy of m with the same contents in the opposite storage order
// (row-major <-> column-major). Dense input stays dense and sparse input
// stays sparse.
flipped m =
    if isSparse? m then
        makeSparse (flippedTypeOf m) (size m) (enumerateSparse m)
    elif not (isRowMajor? m) then
        // generate always builds column-major storage, so build the
        // transpose of m and relabel it as row-major
        transposed
           (generate do r c: at' m c r done
                { rows = width m, columns = height m })
    else
        generate do r c: at' m r c done (size m)
    fi;

// m in row-major storage, flipping it only when necessary.
toRowMajor m =
    if not (isRowMajor? m) then flipped m else m fi;

// m in column-major storage, flipping it only when necessary.
toColumnMajor m =
    if isRowMajor? m then flipped m else m fi;
Chris@161
|
322
|
// Compare the stored data of m1 and m2 directly.
// Prerequisite: m1 and m2 have the same size, the same sparse-p and the
// same storage order (equal' ensures this before calling).
// Dense storage is compared vector by vector with vecComparator. Sparse
// storage needs identical extent, indices and pointers (i.e. the same
// set of stored cells), and its stored values are compared with
// vecComparator. comparator itself is not used here.
equal'' comparator vecComparator m1 m2 =
   (sameVectors a b = all id (map2 vecComparator a b);
    sameSparse a b =
        a.extent == b.extent and
        vecComparator a.values b.values and
        a.indices == b.indices and
        a.pointers == b.pointers;
    case m1 of
    DenseRows d1:
        case m2 of DenseRows d2: sameVectors d1 d2; _: false; esac;
    DenseCols d1:
        case m2 of DenseCols d2: sameVectors d1 d2; _: false; esac;
    SparseCSR d1:
        case m2 of SparseCSR d2: sameSparse d1 d2; _: false; esac;
    SparseCSC d1:
        case m2 of SparseCSC d2: sameSparse d1 d2; _: false; esac;
    esac);

// Core of equal and equalUnder.
// Matrices with no rows or no columns all equal each other; otherwise
// the sizes must match. A storage-order mismatch is resolved by flipping
// m1, and a sparse/dense mismatch by converting the dense side to sparse,
// before the data are compared.
equal' comparator vecComparator m1 m2 =
   (again = equal' comparator vecComparator;
    if isZeroSize? m1 and isZeroSize? m2 then
        true
    elif size m1 != size m2 then
        false
    elif isRowMajor? m1 != isRowMajor? m2 then
        again (flipped m1) m2
    elif isSparse? m1 == isSparse? m2 then
        equal'' comparator vecComparator m1 m2
    elif isSparse? m1 then
        again m1 (toSparse m2)
    else
        again (toSparse m1) m2
    fi);

// Compare matrices using the given comparator for individual cells.
// Matrices with different storage order but the same contents are
// equal, although comparing them is slow.
// NB: when either matrix is sparse, the two can only be equal if they
// have the same set of non-zero cells, whatever the comparator. The
// comparator is applied only to the values stored in those cells.
equalUnder comparator m1 m2 =
    equal' comparator (vec.equalUnder comparator) m1 m2;

// Exact equality of matrix contents, whatever the storage format.
equal m1 m2 =
    equal' (==) vec.equal m1 m2;
Chris@226
|
369
|
// Dense matrix of the given storage type (RowMajor () or ColumnMajor ())
// whose rows or columns are the vectors in data. If data is empty, or
// its first vector is empty, the 0x0 matrix is returned instead.
//!!! NB does not copy data
newMatrix type data =
    if empty? data or vec.empty? (head data) then
        zeroSizeMatrix ()
    else
        case type of
        RowMajor (): DenseRows (array data);
        ColumnMajor (): DenseCols (array data);
        esac
    fi;

// Single-row dense matrix holding the vector v.
//!!! NB does not copy data
newRowVector v = DenseRows (array [v]);

// Single-column dense matrix holding the vector v.
//!!! NB does not copy data
newColumnVector v = DenseCols (array [v]);
Chris@8
|
382
|
// Combine same-size m1 and m2 slice by slice with the vector function op,
// giving a dense matrix in m1's storage order. The slices are rows if m1
// is row-major and columns otherwise; m2 may be stored either way.
denseLinearOp op m1 m2 =
   (slicesOf = if isRowMajor? m1 then asRows else asColumns fi;
    newMatrix (typeOf m1) (map2 op (slicesOf m1) (slicesOf m2)));

// Combine two sparse matrices cell by cell, giving a sparse matrix with
// m1's size and storage order. A cell present only in m1 keeps its value.
// A cell present in m2 becomes op a b, where a is m1's value there (or 0)
// and b is m2's. makeSparse drops cells whose result is 0.
sparseSumOrDifference op m1 m2 =
   (cells = [:];
    rowOf i =
       (if not (i in cells) then cells[i] := [:] fi;
        cells[i]);
    for (enumerate m1) do { i, j, v }:
        r = rowOf i;
        r[j] := v
    done;
    for (enumerate m2) do { i, j, v }:
        r = rowOf i;
        r[j] := op (if j in r then r[j] else 0 fi) v
    done;
    entries = concat
       (map do i:
            r = cells[i];
            map do j: { i, j, v = r[j] } done (keys r)
        done (keys cells));
    makeSparse (typeOf m1) (size m1) entries);

// Shared body of sum' and difference. Fail unless m1 and m2 have the same
// size. Then combine them cell by cell with sparseOp when both are
// sparse, or slice by slice with denseOp otherwise.
elementwise sparseOp denseOp m1 m2 =
    if (size m1) != (size m2) then
        failWith "Matrices are not the same size: \(size m1), \(size m2)";
    elif isSparse? m1 and isSparse? m2 then
        sparseSumOrDifference sparseOp m1 m2
    else
        denseLinearOp denseOp m1 m2
    fi;

// Cell-by-cell sum of two matrices of the same size.
sum' m1 m2 = elementwise (+) bf.add m1 m2;

// Cell-by-cell difference of two matrices of the same size.
difference m1 m2 = elementwise (-) bf.subtract m1 m2;
Chris@257
|
428
|
// Apply a per-cell transformation to m, keeping its storage format and
// order. For sparse m, cellFn is applied to each stored value and the
// result is re-sparsified, so values mapped to 0 are dropped. For dense
// m, vecFn is applied to each row (row-major) or column (column-major).
mapCells cellFn vecFn m =
    if isSparse? m then
        makeSparse (typeOf m) (size m)
            (map do { i, j, v }: { i, j, v = cellFn v } done (enumerate m))
    elif isRowMajor? m then
        newMatrix (typeOf m) (map vecFn (asRows m))
    else
        newMatrix (typeOf m) (map vecFn (asColumns m))
    fi;

// m with every cell multiplied by factor.
scaled factor m =
    mapCells do v: factor * v done (bf.scaled factor) m;

// m with every cell replaced by its absolute value.
abs' m =
    mapCells abs bf.abs m;

// m with every cell for which f is false replaced by 0.
filter' f m =
   (keep x = if f x then x else 0 fi;
    mapCells keep (vec.fromList . (map keep) . vec.list) m);
Chris@258
|
462
|
// Product of sparse m1 and dense m2. size is the caller-computed
// { rows, columns } of the result. Every stored value of m1 is visited
// once per column of m2. Returns a dense column-major matrix.
sparseProductLeft size m1 m2 =
   ({ values, indices, pointers } = case m1 of
        SparseCSR d: d;
        SparseCSC d: d;
        _: failWith "sparseProductLeft called for non-sparse m1";
        esac;
    rowMajor = isRowMajor? m1;
    data = array (map \(new double[size.rows]) [1..size.columns]);
    for [0..size.columns - 1] do j':
        c = getColumn j' m2;
        // p is the major slice that holds stored value ix; the while-loop
        // steps past any empty slices
        var p = 0;
        for [0..length indices - 1] do ix:
            ix == pointers[p+1] loop (p := p + 1);
            i = if rowMajor then p else indices[ix] fi;
            j = if rowMajor then indices[ix] else p fi;
            data[j'][i] := data[j'][i] + (vec.at values ix) * (vec.at c j);
        done;
    done;
    DenseCols (array (map vec.vector (list data))));

// Product of dense m1 and sparse m2. size is the caller-computed
// { rows, columns } of the result. Every stored value of m2 is visited
// once per row of m1. Returns a dense row-major matrix.
sparseProductRight size m1 m2 =
   ({ values, indices, pointers } = case m2 of
        SparseCSR d: d;
        SparseCSC d: d;
        // the error names this function and its sparse argument, m2
        _: failWith "sparseProductRight called for non-sparse m2";
        esac;
    rowMajor = isRowMajor? m2;
    data = array (map \(new double[size.columns]) [1..size.rows]);
    for [0..size.rows - 1] do i':
        r = getRow i' m1;
        // p is the major slice that holds stored value ix; the while-loop
        // steps past any empty slices
        var p = 0;
        for [0..length indices - 1] do ix:
            ix == pointers[p+1] loop (p := p + 1);
            i = if rowMajor then p else indices[ix] fi;
            j = if rowMajor then indices[ix] else p fi;
            data[i'][j] := data[i'][j] + (vec.at values ix) * (vec.at r i);
        done;
    done;
    DenseRows (array (map vec.vector (list data))));
Chris@249
|
502
|
// Product of two sparse matrices. size is the caller-computed
// { rows, columns } of the result. A CSR m2 is first flipped to CSC.
// Then, for each non-empty column of m2, every stored value of m1 is
// multiplied by the matching entry of that column and the products are
// summed per row. The result is a sparse column-major matrix.
sparseProduct size m1 m2 =
    case m2 of
    SparseCSC d:
       ({ values, indices, pointers } = case m1 of
            SparseCSR d1: d1;
            SparseCSC d1: d1;
            _: failWith "sparseProduct called for non-sparse matrices";
            esac;
        rowMajor = isRowMajor? m1;
        // majorOf[ix] is the major slice (row for CSR, column for CSC)
        // holding m1's ix'th stored value
        majorOf = new int[length indices];
        var p = 0;
        for [0..length indices - 1] do ix:
            ix == pointers[p+1] loop (p := p + 1);
            majorOf[ix] := p;
        done;
        // Non-zero { i, j, v } cells of column j' of the product.
        columnEntries j' =
           (col = sparseSlice j' d;
            // row index -> value for the stored cells of column j' of m2
            colValues = mapIntoHash
                (at col.indices) (vec.at col.values)
                [0..length col.indices - 1];
            acc = [:];
            for [0..length indices - 1] do ix:
                i = if rowMajor then majorOf[ix] else indices[ix] fi;
                k = if rowMajor then indices[ix] else majorOf[ix] fi;
                if k in colValues then
                    contribution = (vec.at values ix) * colValues[k];
                    acc[i] := contribution + (if i in acc then acc[i] else 0 fi);
                fi;
            done;
            map do i: { i, j = j', v = acc[i] } done (keys acc));
        makeSparse (ColumnMajor ()) size
            (concat (map columnEntries (nonEmptySlices d))));
    SparseCSR _:
        sparseProduct size m1 (flipped m2);
    _: failWith "sparseProduct called for non-sparse matrices";
    esac;
Chris@251
|
542
|
// Dense product of m1 and m2. size is the caller-computed
// { rows, columns } of the result, which is a dense column-major matrix.
// Cell (i, j) is the sum of the element-wise product of row i of m1 and
// column j of m2. Each row of m1 and each column of m2 is extracted
// once, instead of re-extracting column j for every row.
denseProduct size m1 m2 =
   (rows = array (map ((flip getRow) m1) [0 .. size.rows - 1]);
    DenseCols (array (map do j:
        col = getColumn j m2;
        vec.fromList (map do row: bf.sum (bf.multiply row col) done rows)
    done [0 .. size.columns - 1])));

// Matrix product m1 * m2. Fails unless the column count of m1 equals the
// row count of m2. The algorithm depends on which operands are sparse:
// both sparse gives a sparse column-major result, exactly one sparse
// gives a dense result, and neither gives a dense column-major result.
product m1 m2 =
   (s1 = size m1;
    s2 = size m2;
    if s1.columns != s2.rows then
        failWith "Matrix dimensions incompatible: \(size m1), \(size m2) (\((size m1).columns) != \((size m2).rows))";
    else
        resultSize = { rows = s1.rows, columns = s2.columns };
        if isSparse? m1 and isSparse? m2 then
            sparseProduct resultSize m1 m2
        elif isSparse? m1 then
            sparseProductLeft resultSize m1 m2
        elif isSparse? m2 then
            sparseProductRight resultSize m1 m2
        else
            denseProduct resultSize m1 m2
        fi
    fi);
Chris@98
|
570
|
// Concatenate dense matrices across their storage slices. Slice k of the
// result (a row or column, as chosen by getter and tagger) is slice k of
// every matrix in mm joined end to end. counter, (.rows) or (.columns),
// gives the slice count from the size of the first matrix.
concatAgainstGrain tagger getter counter mm =
   (n = counter (size (head mm));
    joined = map do k: vec.concat (map (getter k) mm) done [0 .. n - 1];
    tagger (array joined));

// Concatenate dense matrices along their storage slices. The result's
// slices are all the slices of the first matrix in mm, then all those of
// the second, and so on.
concatWithGrain tagger getter counter mm =
    tagger (array (concat (map do m:
        map ((flip getter) m) [0 .. counter (size m) - 1]
    done mm)));

// Concatenate sparse matrices mm (first is the head of mm) in the given
// direction, giving a sparse matrix with first's storage order. Each
// matrix's cells are offset by the total height (Vertical ()) or width
// (Horizontal ()) of the matrices before it.
sparseConcat direction first mm =
   (vertical = direction == Vertical ();
    rows = if vertical then sum (map height mm) else height first fi;
    columns = if vertical then width first else sum (map width mm) fi;
    shifted ioff joff ms =
        case ms of
        m::rest:
            (map do { i, j, v }: { i = i + ioff, j = j + joff, v }
             done (enumerate m)) ++
               (if vertical then shifted (ioff + height m) joff rest
                else shifted ioff (joff + width m) rest
                fi);
        _: [];
        esac;
    makeSparse (typeOf first) { rows, columns } (shifted 0 0 mm));

// Fail unless every matrix in mm matches first in the dimension that
// concatenation in the given direction preserves: the row count for
// Horizontal (), the column count for Vertical ().
checkDimensionsFor direction first mm =
   (counter = if direction == Vertical () then (.columns) else (.rows) fi;
    n = counter (size first);
    if any do m: counter (size m) != n done mm then
        failWith "Matrix dimensions incompatible for concat (found \(map do m: counter (size m) done mm) not all of which are \(n))";
    fi);

concat' direction mm = //!!! doc: storage order is taken from first matrix in sequence
    case length mm of
    0: zeroSizeMatrix ();
    1: head mm;
    _:
        first = head mm;
        checkDimensionsFor direction first mm;
        if all isSparse? mm then
            sparseConcat direction first mm
        else
            rowMajor = isRowMajor? first;
            // Whole storage slices can be appended ("with the grain") when
            // the direction runs across slices: vertical for row-major,
            // horizontal for column-major. Otherwise matching slices of
            // each matrix are joined end to end.
            combine =
                if (direction == Vertical ()) == rowMajor
                then concatWithGrain
                else concatAgainstGrain
                fi;
            if rowMajor then combine DenseRows getRow (.rows) mm
            else combine DenseCols getColumn (.columns) mm
            fi
        fi;
    esac;

//!!! doc this filter -- zero-size elts are ignored
// Concatenate the matrices in mm horizontally or vertically, skipping
// any that have no rows or no columns.
concat direction mm =
    concat' direction (filter do m: not (isZeroSize? m) done mm);
Chris@275
|
643
|
//!!! doc note: argument order chosen for consistency with std module slice
//!!! NB always returns dense matrix, should have sparse version
// Rows start (inclusive) to end (exclusive) of m as a dense matrix in m's
// storage order. start is clamped to 0 .. height m, and end to
// start .. height m, so out-of-range or reversed bounds give a shorter or
// empty slice.
rowSlice m start end = //!!! doc: storage order same as input
   (h = height m;
    lo = max 0 (min start h);
    hi = max lo (min end h);
    if isRowMajor? m then
        DenseRows (array (map ((flip getRow) m) [lo .. hi - 1]))
    else
        DenseCols (array (map do v: vec.slice v lo hi done (asColumns m)))
    fi);

//!!! doc note: argument order chosen for consistency with std module slice
//!!! NB always returns dense matrix, should have sparse version
// Columns start (inclusive) to end (exclusive) of m as a dense matrix in
// m's storage order. start is clamped to 0 .. width m, and end to
// start .. width m, so out-of-range or reversed bounds give a narrower
// or empty slice.
columnSlice m start end = //!!! doc: storage order same as input
   (w = width m;
    lo = max 0 (min start w);
    hi = max lo (min end w);
    if not isRowMajor? m then
        DenseCols (array (map ((flip getColumn) m) [lo .. hi - 1]))
    else
        DenseRows (array (map do v: vec.slice v lo hi done (asRows m)))
    fi);
Chris@187
|
679
|
// m cropped and/or zero-padded to newsize, keeping its storage order.
// Sparse m is rebuilt directly, discarding cells outside the new size.
// Dense m is shrunk one dimension at a time (the storage-major dimension
// first when both shrink) and then grown by appending zero matrices,
// recursing until the size matches.
resizedTo newsize m =
   (if newsize == size m then
        m
    elif isSparse? m then
        // don't call makeSparse directly: want to discard
        // out-of-range cells
        newSparseMatrix (typeOf m) newsize (enumerateSparse m)
    elif isZeroSize? m then
        zeroMatrixWithTypeOf m newsize
    else
        growRows = newsize.rows - height m;
        growCols = newsize.columns - width m;
        rowMajor = isRowMajor? m;
        resizedTo newsize
           (if growRows < 0 and (rowMajor or growCols >= 0) then
                rowSlice m 0 newsize.rows
            elif growCols < 0 then
                columnSlice m 0 newsize.columns
            elif growRows > 0 then
                concat (Vertical ())
                    [m, zeroMatrixWithTypeOf m ((size m) with { rows = growRows })]
            else
                concat (Horizontal ())
                    [m, zeroMatrixWithTypeOf m ((size m) with { columns = growCols })]
            fi)
    fi);
Chris@201
|
712
|
// Public interface of the module. The signature below lists only the
// functions the struct actually exports. A stale entry for
// `thresholded`, which has no implementation (it was replaced by
// filter), has been removed.
{
    size,
    width,
    height,
    density,
    nonZeroValues,
    at = at',
    getColumn,
    getRow,
    isRowMajor?,
    isSparse?,
    generate,
    constMatrix,
    randomMatrix,
    zeroMatrix,
    identityMatrix,
    zeroSizeMatrix,
    isZeroSize?,
    equal,
    equalUnder,
    transposed,
    flipped,
    toRowMajor,
    toColumnMajor,
    toSparse,
    toDense,
    scaled,
    resizedTo,
    asRows,
    asColumns,
    sum = sum',
    difference,
    abs = abs',
    filter = filter',
    product,
    concat,
    rowSlice,
    columnSlice,
    newMatrix,
    newRowVector,
    newColumnVector,
    newSparseMatrix,
    enumerate
}
as
{
    //!!! check whether these are right to be .selector rather than just selector

    size is matrix -> { .rows is number, .columns is number },
    width is matrix -> number,
    height is matrix -> number,
    density is matrix -> number,
    nonZeroValues is matrix -> number,
    at is matrix -> number -> number -> number,
    getColumn is number -> matrix -> vector,
    getRow is number -> matrix -> vector,
    isRowMajor? is matrix -> boolean,
    isSparse? is matrix -> boolean,
    generate is (number -> number -> number) -> { .rows is number, .columns is number } -> matrix,
    constMatrix is number -> { .rows is number, .columns is number } -> matrix,
    randomMatrix is { .rows is number, .columns is number } -> matrix,
    zeroMatrix is { .rows is number, .columns is number } -> matrix,
    identityMatrix is { .rows is number, .columns is number } -> matrix,
    zeroSizeMatrix is () -> matrix,
    isZeroSize? is matrix -> boolean,
    equal is matrix -> matrix -> boolean,
    equalUnder is (number -> number -> boolean) -> matrix -> matrix -> boolean,
    transposed is matrix -> matrix,
    flipped is matrix -> matrix,
    toRowMajor is matrix -> matrix,
    toColumnMajor is matrix -> matrix,
    toSparse is matrix -> matrix,
    toDense is matrix -> matrix,
    scaled is number -> matrix -> matrix,
    resizedTo is { .rows is number, .columns is number } -> matrix -> matrix,
    asRows is matrix -> list<vector>,
    asColumns is matrix -> list<vector>,
    sum is matrix -> matrix -> matrix,
    difference is matrix -> matrix -> matrix,
    abs is matrix -> matrix,
    filter is (number -> boolean) -> matrix -> matrix,
    product is matrix -> matrix -> matrix,
    concat is (Horizontal () | Vertical ()) -> list<matrix> -> matrix,
    rowSlice is matrix -> number -> number -> matrix,
    columnSlice is matrix -> number -> number -> matrix,
    newMatrix is (ColumnMajor () | RowMajor ()) -> list<vector> -> matrix,
    newRowVector is vector -> matrix,
    newColumnVector is vector -> matrix,
    newSparseMatrix is (ColumnMajor () | RowMajor ()) -> { .rows is number, .columns is number } -> list<{ .i is number, .j is number, .v is number }> -> matrix,
    enumerate is matrix -> list<{ .i is number, .j is number, .v is number }>
}
Chris@5
|
805
|