Chris@5
|
1
|
Chris@94
|
2 module yetilab.matrix.matrix;
|
Chris@94
|
3
|
Chris@214
|
4 // A matrix is either dense (an array of vectors) or sparse (compressed row or column storage).
|
Chris@94
|
5
|
Chris@101
|
6 // A matrix can be stored in either column-major (the default) or
|
Chris@101
|
7 // row-major format. Storage order is an efficiency concern only:
|
Chris@101
|
8 // every API function operating on matrix objects will return the same
|
Chris@101
|
9 // result regardless of storage order. (The transpose function just
|
Chris@101
|
10 // switches the row/column order without moving the elements.)
|
Chris@18
|
11
|
Chris@222
|
12 vec = load yetilab.vector.vector;
|
Chris@222
|
13 bf = load yetilab.vector.blockfuncs;
|
Chris@9
|
14
|
Chris@222
|
15 load yetilab.vector.vectortype;
|
Chris@195
|
16 load yetilab.matrix.matrixtype;
|
Chris@195
|
17
|
// Return the dimensions of m as a { rows, columns } record.
// Dense storage: the major dimension is the number of stored vectors and
// the minor dimension is the length of the first vector (0 if there are
// no vectors). Sparse storage: the major dimension is one less than the
// number of pointers and the minor dimension is the stored extent.
size m =
   (firstLength vv = if length vv > 0 then vec.length vv[0] else 0 fi;
    case m of
    DenseCols vv:
        { rows = firstLength vv, columns = length vv };
    DenseRows vv:
        { rows = length vv, columns = firstLength vv };
    SparseCSC d:
        { rows = d.extent, columns = length d.pointers - 1 };
    SparseCSR d:
        { rows = length d.pointers - 1, columns = d.extent };
    esac);
Chris@208
|
43
|
// Number of columns in m.
width m =
   (dims = size m;
    dims.columns);

// Number of rows in m.
height m =
   (dims = size m;
    dims.rows);
Chris@208
|
46
|
// Count the cells of m that hold a non-zero value.
// Dense matrices are scanned cell by cell. Sparse matrices report the
// number of stored values (makeSparse discards zero entries when it
// builds sparse storage).
nonZeroValues m =
   (countDense vv =
        sum (map do v: length (filter do x: x != 0 done (vec.list v)) done vv);
    case m of
    SparseCSR d: vec.length d.values;
    SparseCSC d: vec.length d.values;
    DenseRows vv: countDense vv;
    DenseCols vv: countDense vv;
    esac);
Chris@242
|
59
|
// Fraction of the cells of m that hold a non-zero value (0 to 1).
// A matrix with no cells (zero rows or zero columns) has density 0.
// Without that check the division would be 0 / 0.
density m =
   ({ rows, columns } = size m;
    cells = rows * columns;
    if cells == 0 then 0
    else (nonZeroValues m) / cells
    fi);
Chris@249
|
64
|
// Return the stored values and minor-axis indices of major slice n of
// sparse data d (row n for CSR, column n for CSC), as a
// { values, indices } record.
sparseSlice n d =
   (lo = d.pointers[n];
    hi = d.pointers[n+1];
    {
        values = vec.slice d.values lo hi,
        indices = slice d.indices lo hi,
    });
Chris@236
|
72
|
// Return, as an array in ascending order, the indices of the major
// slices of sparse data d that contain at least one stored value.
nonEmptySlices d =
    array
       (filter do k: d.pointers[k] != d.pointers[k+1] done
            [0 .. length d.pointers - 2]);
Chris@252
|
81
|
// Look up minor index m within major slice n of sparse data d.
// Returns 0 when nothing is stored at that index. If the index appears
// more than once, the last occurrence wins.
fromSlice n m d =
   (s = sparseSlice n d;
    fold do found k:
        if s.indices[k] == m then vec.at s.values k else found fi
    done 0 [0 .. length s.indices - 1]);
Chris@236
|
91
|
// Expand major slice n of sparse data d into a dense vector of length
// d.extent, with zeros in every position that has no stored value.
filledSlice n d =
   (s = sparseSlice n d;
    count = length s.indices;
    out = new double[d.extent];
    for [0 .. count - 1] do k:
        pos = s.indices[k];
        out[pos] := vec.at s.values k
    done;
    vec.vector out);
Chris@236
|
99
|
// Value of m at (row, col). Dense storage is indexed directly. Sparse
// storage searches the relevant major slice and yields 0 if nothing is
// stored there.
at' m row col =
    case m of
    SparseCSR d: fromSlice row col d;
    SparseCSC d: fromSlice col row d;
    DenseRows vv: vec.at vv[row] col;
    DenseCols vv: vec.at vv[col] row;
    esac;
Chris@208
|
107
|
// Column j of m as a vector.
// Dense column-major storage returns the stored vector itself (not a
// copy). Sparse column-major storage expands the stored column. Any
// other storage assembles the column one cell at a time.
getColumn j m =
    case m of
    SparseCSC d: filledSlice j d;
    DenseCols vv: vv[j];
    _:
        rowIndices = [0 .. height m - 1];
        vec.fromList (map do r: at' m r j done rowIndices);
    esac;
Chris@208
|
114
|
// Row i of m as a vector.
// Dense row-major storage returns the stored vector itself (not a
// copy). Sparse row-major storage expands the stored row. Any other
// storage assembles the row one cell at a time.
getRow i m =
    case m of
    SparseCSR d: filledSlice i d;
    DenseRows vv: vv[i];
    _:
        columnIndices = [0 .. width m - 1];
        vec.fromList (map do c: at' m i c done columnIndices);
    esac;
Chris@208
|
121
|
// All rows of m as a list of vectors, top to bottom.
asRows m =
    map (flip getRow m) [0 .. height m - 1];

// All columns of m as a list of vectors, left to right.
asColumns m =
    map (flip getColumn m) [0 .. width m - 1];
Chris@257
|
127
|
// True if m is stored row by row (DenseRows or SparseCSR).
isRowMajor? m =
    case m of
    DenseCols _: false;
    SparseCSC _: false;
    DenseRows _: true;
    SparseCSR _: true;
    esac;
Chris@234
|
135
|
// True if m uses compressed sparse storage (SparseCSR or SparseCSC).
isSparse? m =
    case m of
    SparseCSR _: true;
    SparseCSC _: true;
    DenseRows _: false;
    DenseCols _: false;
    esac;
Chris@94
|
143
|
// Storage-order tag for m: RowMajor () or ColumnMajor ().
typeOf m =
    if not (isRowMajor? m) then ColumnMajor () else RowMajor () fi;

// The storage-order tag opposite to that of m.
flippedTypeOf m =
    if not (isRowMajor? m) then RowMajor () else ColumnMajor () fi;
Chris@261
|
153
|
// Allocate an array of `columns` zero-filled vectors, each of length
// `rows`. If rows < 1 the array is empty, whatever the column count.
newColumnMajorStorage { rows, columns } =
   (column () = vec.zeros rows;
    array
       (if rows < 1 then []
        else map \(column ()) [1 .. columns]
        fi));
Chris@94
|
158
|
// A rows x columns matrix of zeros, in dense column-major storage.
zeroMatrix { rows, columns } =
   (storage = newColumnMajorStorage { columns, rows };
    DenseCols storage);
Chris@201
|
161
|
// A rows x columns dense zero matrix with the same storage order as m.
zeroMatrixWithTypeOf m { rows, columns } =
    if not (isRowMajor? m) then
        DenseCols (newColumnMajorStorage { rows, columns })
    else
        // row-major: allocate `rows` vectors, each of length `columns`
        DenseRows (newColumnMajorStorage { rows = columns, columns = rows })
    fi;
Chris@5
|
168
|
// A matrix with no rows and no columns.
zeroSizeMatrix () =
    zeroMatrix { columns = 0, rows = 0 };
Chris@214
|
170
|
// Build a dense column-major matrix of the given size whose cell
// (row, col) is f row col. f is called column by column, top to bottom
// within each column. If either dimension is below 1 the result is the
// zero-size matrix.
generate f { rows, columns } =
    if rows < 1 or columns < 1 then
        zeroSizeMatrix ()
    else
        makeColumn col =
           (cells = new double[rows];
            for [0 .. rows - 1] do row:
                cells[row] := f row col
            done;
            vec.vector cells);
        DenseCols (array (map makeColumn [0 .. columns - 1]))
    fi;
Chris@5
|
182
|
// Swap the i and j fields of every { i, j, v } entry. This turns a
// column-major enumeration into row/column coordinates, and vice versa.
swapij entries =
    map do e: { i = e.j, j = e.i, v = e.v } done entries;
Chris@255
|
185
|
// Enumerate the stored cells of a sparse matrix as { i, j, v } records:
// i is the row, j the column, v the value. Entries are grouped by major
// slice (row for CSR, column for CSC) in slice order. A dense matrix
// yields an empty list.
//!!! should use { row = , column = , value = } instead of i, j, v?
enumerateSparse m =
   (slices d =
        concat
           (map do n:
                s = sparseSlice n d;
                map2 do minor value: { i = n, j = minor, v = value } done
                    (s.indices) (vec.list s.values)
            done [0 .. length d.pointers - 2]);
    case m of
    SparseCSR d: slices d;
    SparseCSC d: swapij (slices d);
    _: [];
    esac);
Chris@235
|
202
|
// Enumerate every cell of a dense matrix, zeros included, as { i, j, v }
// records: i is the row, j the column, v the value. Entries are grouped
// by stored vector (rows for DenseRows, columns for DenseCols). A sparse
// matrix yields an empty list.
enumerateDense m =
   (cells vv =
        concat
           (map2 do major vect:
                map2 do minor value: { i = major, j = minor, v = value } done
                    [0 .. vec.length vect - 1]
                    (vec.list vect)
            done [0 .. length vv - 1] (list vv));
    case m of
    DenseRows vv: cells vv;
    DenseCols vv: swapij (cells vv);
    _: [];
    esac);
Chris@255
|
217
|
// Enumerate the cells of m as { i, j, v } records: every cell for a
// dense matrix, only the stored cells for a sparse one.
enumerate m =
    if not (isSparse? m) then enumerateDense m
    else enumerateSparse m
    fi;
Chris@255
|
220
|
// Make a sparse matrix from { i, j, v } entries whose i, j values are
// known to be within range.
// - type (RowMajor () or ColumnMajor ()) selects CSR or CSC storage.
// - Entries with v == 0 are dropped.
// - The remaining entries are sorted by major index, then minor index.
// - Duplicate (i, j) entries are not merged.
makeSparse type size data =
   (isRow = case type of RowMajor (): true; ColumnMajor (): false esac;
    ordered =
        sortBy do a b:
            if a.maj == b.maj then a.min < b.min else a.maj < b.maj fi
        done
           (map
                if isRow then
                    do { i, j, v }: { maj = i, min = j, v } done;
                else
                    do { i, j, v }: { maj = j, min = i, v } done;
                fi
               (filter do d: d.v != 0 done data));
    tagger = if isRow then SparseCSR else SparseCSC fi;
    majorSize = if isRow then size.rows else size.columns fi;
    minorSize = if isRow then size.columns else size.rows fi;
    pointers = array [0];
    // pointers ends up with one entry per major slice plus the final end
    // offset. Reserve capacity from the major dimension: rows for CSR,
    // columns for CSC.
    setArrayCapacity pointers (majorSize + 1);
    // n: next major index whose start pointer has not been pushed yet.
    // i: number of entries consumed so far.
    // data: remaining major indices, in ascending order.
    fillPointers n i data =
        if n < majorSize then
            case data of
            d::rest:
                (for [n..d-1] \(push pointers i);
                 fillPointers d (i+1) rest);
            _:
                for [n..majorSize-1] \(push pointers i);
            esac;
        fi;
    fillPointers 0 0 (map (.maj) ordered);
    tagger {
        values = vec.fromList (map (.v) ordered),
        indices = array (map (.min) ordered),
        pointers,
        extent = minorSize,
    });
Chris@238
|
258
|
// Public constructor for sparse matrices. type selects CSR (RowMajor ())
// or CSC (ColumnMajor ()) storage. Entries whose i or j is not an integer
// in [0, size.rows) or [0, size.columns) are discarded, so callers such as
// resizedTo may pass cells that fall outside the new size.
newSparseMatrix type size data =
   (inRange x limit = x == int x and x >= 0 and x < limit;
    valid = filter do e: inRange e.i size.rows and inRange e.j size.columns done data;
    makeSparse type size valid);
Chris@261
|
270
|
// Convert m to sparse storage with the same storage order. Zero cells
// are dropped. A matrix that is already sparse is returned unchanged.
toSparse m =
    if not (isSparse? m) then
        makeSparse (typeOf m) (size m) (enumerateDense m)
    else
        m
    fi;
Chris@238
|
276
|
// Convert m to dense storage with the same storage order. A matrix that
// is already dense is returned unchanged.
toDense m =
    if isSparse? m then
        if isRowMajor? m then DenseRows (array (asRows m))
        else DenseCols (array (asColumns m))
        fi
    else
        m
    fi;
Chris@235
|
284
|
// A matrix of the given size with every cell set to n.
constMatrix n = generate do row col: n done;

// A matrix of the given size filled with values from Math#random().
randomMatrix = generate do row col: Math#random() done;

// Identity matrix of the given size: 1 where row == col, 0 elsewhere.
// The previous definition (constMatrix 1) set every cell to 1. For a
// non-square size, the ones run down the diagonal from the top-left cell.
identityMatrix = generate do row col: if row == col then 1 else 0 fi done;
Chris@5
|
288
|
// Transpose m by reinterpreting its storage: the stored vectors or
// slices are reused as-is under the opposite orientation tag, so no
// element is copied or moved.
transposed m =
    case m of
    DenseCols d: DenseRows d;
    DenseRows d: DenseCols d;
    SparseCSC d: SparseCSR d;
    SparseCSR d: SparseCSC d;
    esac;
Chris@100
|
296
|
// Return a matrix with the same contents as m but the opposite storage
// order (row-major <-> column-major). Sparseness is preserved. Unlike
// transposed, this physically rearranges the elements.
flipped m =
   (cell = at' m;
    if isSparse? m then
        makeSparse (flippedTypeOf m) (size m) (enumerateSparse m)
    elif not (isRowMajor? m) then
        // generate always builds column-major storage, so build the
        // transpose and reinterpret it as row-major
        transposed
           (generate do r c: cell c r done
                { rows = width m, columns = height m })
    else
        generate do r c: cell r c done (size m)
    fi);
Chris@100
|
309
|
// m in row-major storage. m itself is returned if it is already row-major.
toRowMajor m =
    if not (isRowMajor? m) then flipped m else m fi;

// m in column-major storage. m itself is returned if it is already
// column-major.
toColumnMajor m =
    if isRowMajor? m then flipped m else m fi;
Chris@161
|
315
|
// Compare two matrices. Prerequisite: m1 and m2 already have the same
// sparseness and storage order.
// - Dense: corresponding stored vectors must all satisfy vecComparator.
// - Sparse: extents, indices and pointers must be identical, and the
//   stored values must satisfy vecComparator.
// Differing storage kinds compare unequal. comparator is unused here;
// it is accepted so the signature matches equal'.
equal'' comparator vecComparator m1 m2 =
   (sameVectors a b = all id (map2 vecComparator a b);
    sameSparse a b =
        a.extent == b.extent and
        vecComparator a.values b.values and
        a.indices == b.indices and
        a.pointers == b.pointers;
    case m1 of
    SparseCSR a:
        case m2 of SparseCSR b: sameSparse a b; _: false; esac;
    SparseCSC a:
        case m2 of SparseCSC b: sameSparse a b; _: false; esac;
    DenseRows a:
        case m2 of DenseRows b: sameVectors a b; _: false; esac;
    DenseCols a:
        case m2 of DenseCols b: sameVectors a b; _: false; esac;
    esac);
Chris@238
|
334
|
// Compare two matrices of any storage kind.
// - Different sizes are unequal.
// - Otherwise m1 is flipped to match m2's storage order.
// - If exactly one operand is sparse, the dense one is converted to sparse.
// Finally equal'' compares them.
equal' comparator vecComparator m1 m2 =
   (again = equal' comparator vecComparator;
    if size m1 != size m2 then false
    elif isRowMajor? m1 != isRowMajor? m2 then again (flipped m1) m2
    elif isSparse? m1 and not (isSparse? m2) then again m1 (toSparse m2)
    elif isSparse? m2 and not (isSparse? m1) then again (toSparse m1) m2
    else equal'' comparator vecComparator m1 m2
    fi);
Chris@97
|
349
|
// Compare matrices using the given comparator for individual cells.
// Matrices with different storage order but the same contents are
// equal, although comparing them is slow.
// Sparse matrices (including a dense operand converted to sparse for
// the comparison) can only be equal if they store the same set of
// non-zero cells. The pointers and indices are compared exactly, so
// this holds regardless of the comparator used.
equalUnder comparator m1 m2 =
    equal' comparator (vec.equalUnder comparator) m1 m2;

// Exact equality of matrix contents, independent of storage order.
equal m1 m2 =
    equal' (==) vec.equal m1 m2;
Chris@226
|
360
|
// Build a dense matrix from a list of vectors: rows for RowMajor (),
// columns for ColumnMajor (). If there are no vectors, or the first one
// is empty, the result is the zero-size matrix.
//!!! NB does not copy data
newMatrix type data =
    if empty? data or vec.empty? (head data) then
        zeroSizeMatrix ()
    else
        case type of
        RowMajor (): DenseRows (array data);
        ColumnMajor (): DenseCols (array data);
        esac
    fi;
Chris@96
|
367
|
// A single-row matrix whose only row is the vector v.
//!!! NB does not copy data
newRowVector v =
    DenseRows (array [v]);

// A single-column matrix whose only column is the vector v.
//!!! NB does not copy data
newColumnVector v =
    DenseCols (array [v]);
Chris@8
|
373
|
// Combine m1 and m2 vector by vector with op. Vectors are taken as rows
// if m1 is row-major, otherwise as columns. The result is a dense matrix
// with m1's storage order.
denseLinearOp op m1 m2 =
   (split = if isRowMajor? m1 then asRows else asColumns fi;
    newMatrix (typeOf m1) (map2 op (split m1) (split m2)));
Chris@257
|
382
|
// Combine two sparse matrices cell by cell with op (e.g. (+) or (-)).
// Every cell of m1 is recorded first. Each cell of m2 is then folded in
// as op old new, with old = 0 where m1 had nothing. The result is sparse,
// with m1's storage order and size. Zero results are dropped by makeSparse.
sparseSumOrDifference op m1 m2 =
   (acc = [:];
    rowOf i =
       (if not (i in acc) then acc[i] := [:] fi;
        acc[i]);
    for (enumerate m1) do e:
        r = rowOf e.i;
        r[e.j] := e.v
    done;
    for (enumerate m2) do e:
        r = rowOf e.i;
        r[e.j] := op (if e.j in r then r[e.j] else 0 fi) e.v
    done;
    entries =
        concat
           (map do i:
                r = acc[i];
                map do j: { i, j, v = r[j] } done (keys r)
            done (keys acc));
    makeSparse (typeOf m1) (size m1) entries);
Chris@257
|
401
|
// Element-wise sum of two matrices of the same size. If both are sparse
// the result is sparse; otherwise it is dense. The storage order follows
// m1. Fails if the sizes differ.
sum' m1 m2 =
    if (size m1) == (size m2) then
        if isSparse? m1 and isSparse? m2 then sparseSumOrDifference (+) m1 m2
        else denseLinearOp bf.add m1 m2
        fi
    else
        failWith "Matrices are not the same size: \(size m1), \(size m2)"
    fi;
Chris@98
|
410
|
// Element-wise difference m1 - m2 of two matrices of the same size. If
// both are sparse the result is sparse; otherwise it is dense. The
// storage order follows m1. Fails if the sizes differ.
difference m1 m2 =
    if (size m1) == (size m2) then
        if isSparse? m1 and isSparse? m2 then sparseSumOrDifference (-) m1 m2
        else denseLinearOp bf.subtract m1 m2
        fi
    else
        failWith "Matrices are not the same size: \(size m1), \(size m2)"
    fi;
Chris@257
|
419
|
// Multiply every cell of m by factor. Sparseness and storage order are
// preserved.
scaled factor m =
    if not (isSparse? m) then
        vectors = if isRowMajor? m then asRows m else asColumns m fi;
        newMatrix (typeOf m) (map (bf.scaled factor) vectors)
    else
        makeSparse (typeOf m) (size m)
           (map do e: { i = e.i, j = e.j, v = factor * e.v } done (enumerate m))
    fi;
Chris@229
|
429
|
// Replace every cell of m by its absolute value. Sparseness and storage
// order are preserved.
abs' m =
    if not (isSparse? m) then
        vectors = if isRowMajor? m then asRows m else asColumns m fi;
        newMatrix (typeOf m) (map bf.abs vectors)
    else
        makeSparse (typeOf m) (size m)
           (map do e: { i = e.i, j = e.j, v = abs e.v } done (enumerate m))
    fi;
Chris@229
|
439
|
// Keep the cells of m whose value satisfies f and set all others to
// zero. Sparseness and storage order are preserved; in sparse storage
// the zeroed cells are dropped.
filter f m =
   (keep x = if f x then x else 0 fi;
    if isSparse? m then
        makeSparse (typeOf m) (size m)
           (map do e: { i = e.i, j = e.j, v = keep e.v } done (enumerate m))
    else
        keepAll v = vec.fromList (map keep (vec.list v));
        vectors = if isRowMajor? m then asRows m else asColumns m fi;
        newMatrix (typeOf m) (map keepAll vectors)
    fi);
Chris@258
|
453
|
// Product of sparse m1 and dense m2, returned as a dense column-major
// matrix of the given size. Result column col accumulates
// v * m2[j][col] into row i for every stored cell { i, j, v } of m1.
sparseProductLeft size m1 m2 =
   (cells = enumerateSparse m1;
    buildColumn col =
       (acc = new double[size.rows];
        src = getColumn col m2;
        for cells do { v, i, j }:
            acc[i] := acc[i] + v * (vec.at src j)
        done;
        vec.vector acc);
    DenseCols (array (map buildColumn [0 .. size.columns - 1])));
Chris@249
|
464
|
// Product of dense m1 and sparse m2, returned as a dense row-major
// matrix of the given size. Result row row accumulates
// m1[row][i] * v into column j for every stored cell { i, j, v } of m2.
sparseProductRight size m1 m2 =
   (cells = enumerateSparse m2;
    buildRow row =
       (acc = new double[size.columns];
        src = getRow row m1;
        for cells do { v, i, j }:
            acc[j] := acc[j] + v * (vec.at src i)
        done;
        vec.vector acc);
    DenseRows (array (map buildRow [0 .. size.rows - 1])));
Chris@249
|
475
|
// Product of two sparse matrices, returned as a sparse column-major
// matrix of the given size. A CSR right operand is first flipped to CSC
// so its columns can be walked directly. For each non-empty column col
// of m2, every stored cell { i, j, v } of m1 whose j is stored in that
// column contributes v * m2[j][col] to result cell (i, col). Fails for
// a dense m2.
sparseProduct size m1 m2 =
    case m2 of
    SparseCSR _:
        sparseProduct size m1 (flipped m2);
    SparseCSC d:
       (cells = enumerateSparse m1;
        productColumn col =
           (s = sparseSlice col d;
            // row index within this column of m2 -> stored value
            colValues = mapIntoHash
               (at s.indices) (vec.at s.values)
                [0 .. length s.indices - 1];
            sums = [:];
            for cells do { v, i, j }:
                if j in colValues then
                    p = v * colValues[j];
                    sums[i] := p + (if i in sums then sums[i] else 0 fi);
                fi
            done;
            map do r: { i = r, j = col, v = sums[r] } done (keys sums));
        makeSparse (ColumnMajor ()) size
           (concat (map productColumn (nonEmptySlices d))));
    _:
        failWith "sparseProduct called for non-sparse matrices";
    esac;
Chris@251
|
502
|
// Product of two dense matrices, returned as a dense column-major
// matrix of the given size. Cell (i, j) is the dot product of row i of
// m1 and column j of m2.
denseProduct size m1 m2 =
   (data = array (map \(new double[size.rows]) [1..size.columns]);
    // Fetch each column of m2 once, up front. getColumn may have to
    // assemble a column cell by cell (e.g. when m2 is row-major), and
    // calling it in the inner loop repeated that work for every row of m1.
    cols = array (map do j: getColumn j m2 done [0..size.columns - 1]);
    for [0..size.rows - 1] do i:
        row = getRow i m1;
        for [0..size.columns - 1] do j:
            data[j][i] := bf.sum (bf.multiply row cols[j]);
        done;
    done;
    DenseCols (array (map vec.vector (list data))));
Chris@249
|
512
|
// Matrix product m1 * m2. Fails unless m1 has as many columns as m2
// has rows. The helper used depends on which operands are sparse:
// - both sparse: sparseProduct
// - only m1 sparse: sparseProductLeft
// - only m2 sparse: sparseProductRight
// - both dense: denseProduct
product m1 m2 =
   (s1 = size m1;
    s2 = size m2;
    if s1.columns == s2.rows then
        out = { rows = s1.rows, columns = s2.columns };
        if isSparse? m1 and isSparse? m2 then sparseProduct out m1 m2
        elif isSparse? m1 then sparseProductLeft out m1 m2
        elif isSparse? m2 then sparseProductRight out m1 m2
        else denseProduct out m1 m2
        fi
    else
        failWith "Matrix dimensions incompatible: \(size m1), \(size m2) (\((size m1).columns) != \((size m2).rows))"
    fi);
Chris@98
|
530
|
// Concatenate the matrices in mm across their storage order: the k-th
// stored vector of the result is the concatenation of the k-th vectors
// (fetched with getter) of every matrix. The number of vectors is taken
// from the first matrix via counter, and the result is wrapped with
// tagger.
concatAgainstGrain tagger getter counter mm =
   (joinAt k = vec.concat (map (getter k) mm);
    count = counter (size (head mm));
    tagger (array (map joinAt [0 .. count - 1])));
Chris@177
|
537
|
// Concatenate the matrices in mm along their storage order: the stored
// vectors of each matrix (fetched with getter, counted with counter) are
// appended one matrix after another. The result is wrapped with tagger.
concatWithGrain tagger getter counter mm =
   (vectorsOf m = map (flip getter m) [0 .. counter (size m) - 1];
    tagger (array (concat (map vectorsOf mm))));
Chris@177
|
545
|
// Concatenate the sparse matrices mm (first is head mm) vertically or
// horizontally into one sparse matrix with first's storage order. Each
// matrix's cells are shifted down by the total height (vertical), or
// right by the total width (horizontal), of the matrices before it.
sparseConcat direction first mm =
   (vertical = direction == Vertical ();
    rows = if vertical then sum (map height mm) else height first fi;
    columns = if vertical then width first else sum (map width mm) fi;
    shifted ioff joff mats =
        case mats of
        m::rest:
            moved = map do { i, j, v }: { i = i + ioff, j = j + joff, v }
                done (enumerate m);
            next =
                if vertical then shifted (ioff + height m) joff rest
                else shifted ioff (joff + width m) rest
                fi;
            moved ++ next;
        _: [];
        esac;
    makeSparse (typeOf first) { rows, columns } (shifted 0 0 mm));
Chris@259
|
564
|
// Fail unless every matrix in mm matches first in the dimension that
// must agree for concatenation: the row count for Horizontal (), the
// column count for Vertical ().
checkDimensionsFor direction first mm =
   (counter = if direction == Horizontal () then (.rows) else (.columns) fi;
    n = counter (size first);
    allAgree = all do m: counter (size m) == n done mm;
    if not allAgree then
        failWith "Matrix dimensions incompatible for concat (found \(map do m: counter (size m) done mm) not all of which are \(n))";
    fi);
Chris@178
|
571
|
// Concatenate the matrices in mm horizontally or vertically.
// - An empty list gives the zero-size matrix.
// - A single matrix is returned as-is.
// - Otherwise, after checking dimensions, the result is sparse if every
//   input is sparse and dense otherwise.
// The storage order is taken from the first matrix in the sequence.
concat direction mm =
   (n = length mm;
    if n == 0 then
        zeroSizeMatrix ()
    elif n == 1 then
        head mm
    else
        first = head mm;
        checkDimensionsFor direction first mm;
        if all isSparse? mm then
            sparseConcat direction first mm
        else
            // The dense strategy depends on whether direction runs along
            // the first matrix's storage order:
            //   horizontal, row-major: against grain with rows
            //   horizontal, col-major: with grain with cols
            //   vertical, row-major:   with grain with rows
            //   vertical, col-major:   against grain with cols
            horizontal = direction == Horizontal ();
            if isRowMajor? first then
                if horizontal then concatAgainstGrain DenseRows getRow (.rows) mm
                else concatWithGrain DenseRows getRow (.rows) mm
                fi
            else
                if horizontal then concatWithGrain DenseCols getColumn (.columns) mm
                else concatAgainstGrain DenseCols getColumn (.columns) mm
                fi
            fi
        fi
    fi);
Chris@177
|
599
|
// Rows start (inclusive) to end (exclusive) of m, as a dense matrix with
// m's storage order. The argument order matches the standard slice
// function.
rowSlice m start end =
    if not (isRowMajor? m) then
        // column-major: cut the same row range out of every column
        DenseCols (array (map do col: vec.slice col start end done (asColumns m)))
    else
        DenseRows (array (map do r: getRow r m done [start .. end - 1]))
    fi;
Chris@187
|
607
|
// Columns start (inclusive) to end (exclusive) of m, as a dense matrix
// with m's storage order. The argument order matches the standard slice
// function.
columnSlice m start end =
    if isRowMajor? m then
        // row-major: cut the same column range out of every row
        DenseRows (array (map do r: vec.slice r start end done (asRows m)))
    else
        DenseCols (array (map do c: getColumn c m done [start .. end - 1]))
    fi;
Chris@187
|
615
|
// Return m resized to newsize. Cells outside the new size are dropped,
// and new cells are zero.
// - Sparse matrices are rebuilt directly and stay sparse.
// - Dense matrices keep their storage order and are adjusted one
//   dimension per step, recursing until the size matches. Shrinking
//   comes before growing, and the major dimension is shrunk first.
resizedTo newsize m =
    if newsize == size m then
        m
    elif isSparse? m then
        // don't call makeSparse directly: want to discard
        // out-of-range cells
        newSparseMatrix (typeOf m) newsize (enumerateSparse m)
    elif height m == 0 or width m == 0 then
        zeroMatrixWithTypeOf m newsize
    else
        rowDelta = newsize.rows - height m;
        colDelta = newsize.columns - width m;
        cutRows () = rowSlice m 0 newsize.rows;
        cutColumns () = columnSlice m 0 newsize.columns;
        grow () =
            if rowDelta > 0 then
                concat (Vertical ())
                    [m, zeroMatrixWithTypeOf m ((size m) with { rows = rowDelta })]
            else
                concat (Horizontal ())
                    [m, zeroMatrixWithTypeOf m ((size m) with { columns = colDelta })]
            fi;
        next =
            if isRowMajor? m then
                if rowDelta < 0 then cutRows ()
                elif colDelta < 0 then cutColumns ()
                else grow ()
                fi
            else
                if colDelta < 0 then cutColumns ()
                elif rowDelta < 0 then cutRows ()
                else grow ()
                fi
            fi;
        resizedTo newsize next
    fi;
Chris@201
|
648
|
// Public interface of the module: the exported bindings, cast to their
// declared types. Every field declared in the type below must be present
// in the record. The former `thresholded` declaration had no definition
// in this module, so it has been removed from the signature.
{
    size,
    width,
    height,
    density,
    nonZeroValues,
    at = at',
    getColumn,
    getRow,
    isRowMajor?,
    isSparse?,
    generate,
    constMatrix,
    randomMatrix,
    zeroMatrix,
    identityMatrix,
    zeroSizeMatrix,
    equal,
    equalUnder,
    transposed,
    flipped,
    toRowMajor,
    toColumnMajor,
    toSparse,
    toDense,
    scaled,
    resizedTo,
    asRows,
    asColumns,
    sum = sum',
    difference,
    abs = abs',
    filter,
    product,
    concat,
    rowSlice,
    columnSlice,
    newMatrix,
    newRowVector,
    newColumnVector,
    newSparseMatrix,
    enumerate
}
as
{
    //!!! check whether these are right to be .selector rather than just selector

    size is matrix -> { .rows is number, .columns is number },
    width is matrix -> number,
    height is matrix -> number,
    density is matrix -> number,
    nonZeroValues is matrix -> number,
    at is matrix -> number -> number -> number,
    getColumn is number -> matrix -> vector,
    getRow is number -> matrix -> vector,
    isRowMajor? is matrix -> boolean,
    isSparse? is matrix -> boolean,
    generate is (number -> number -> number) -> { .rows is number, .columns is number } -> matrix,
    constMatrix is number -> { .rows is number, .columns is number } -> matrix,
    randomMatrix is { .rows is number, .columns is number } -> matrix,
    zeroMatrix is { .rows is number, .columns is number } -> matrix,
    identityMatrix is { .rows is number, .columns is number } -> matrix,
    zeroSizeMatrix is () -> matrix,
    equal is matrix -> matrix -> boolean,
    equalUnder is (number -> number -> boolean) -> matrix -> matrix -> boolean,
    transposed is matrix -> matrix,
    flipped is matrix -> matrix,
    toRowMajor is matrix -> matrix,
    toColumnMajor is matrix -> matrix,
    toSparse is matrix -> matrix,
    toDense is matrix -> matrix,
    scaled is number -> matrix -> matrix,
    resizedTo is { .rows is number, .columns is number } -> matrix -> matrix,
    asRows is matrix -> list<vector>,
    asColumns is matrix -> list<vector>,
    sum is matrix -> matrix -> matrix,
    difference is matrix -> matrix -> matrix,
    abs is matrix -> matrix,
    filter is (number -> boolean) -> matrix -> matrix,
    product is matrix -> matrix -> matrix,
    concat is (Horizontal () | Vertical ()) -> list<matrix> -> matrix,
    rowSlice is matrix -> number -> number -> matrix,
    columnSlice is matrix -> number -> number -> matrix,
    newMatrix is (ColumnMajor () | RowMajor ()) -> list<vector> -> matrix,
    newRowVector is vector -> matrix,
    newColumnVector is vector -> matrix,
    newSparseMatrix is (ColumnMajor () | RowMajor ()) -> { .rows is number, .columns is number } -> list<{ .i is number, .j is number, .v is number }> -> matrix,
    enumerate is matrix -> list<{ .i is number, .j is number, .v is number }>
}
Chris@5
|
739
|