annotate yetilab/matrix/matrix.yeti @ 271:c206de7c3018

Faster sparseProductLeft. Similar code would work for other sparseProducts
author Chris Cannam
date Thu, 23 May 2013 16:12:21 +0100
parents 53ff481f1a41
children 2ebda6646c40
rev   line source
Chris@5 1
Chris@94 2 module yetilab.matrix.matrix;
Chris@94 3
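// A matrix is stored either densely, as an array of row or column
// vectors, or sparsely, in compressed sparse row (CSR) or compressed
// sparse column (CSC) form.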
Chris@94 5
// A matrix can be stored in either column-major (the default) or
// row-major format. Storage order is an efficiency concern only:
// every API function operating on matrix objects will return the same
// result regardless of storage order. (The transposed function just
// switches the row/column order without moving the elements.)
Chris@18 11
Chris@222 12 vec = load yetilab.vector.vector;
Chris@222 13 bf = load yetilab.vector.blockfuncs;
Chris@9 14
Chris@222 15 load yetilab.vector.vectortype;
Chris@195 16 load yetilab.matrix.matrixtype;
Chris@195 17
// Return the dimensions of m as a { rows, columns } record.
// Dense storage counts its vectors for the major dimension and reads
// the minor dimension from the first vector (0 when there are none).
// Sparse storage has one more slice pointer than it has major slices,
// and records the minor dimension as its extent.
size m =
    case m of
    DenseRows rowVecs:
        nRows = length rowVecs;
        {
            rows = nRows,
            columns = if nRows == 0 then 0 else vec.length rowVecs[0] fi,
        };
    DenseCols colVecs:
        nCols = length colVecs;
        {
            rows = if nCols == 0 then 0 else vec.length colVecs[0] fi,
            columns = nCols,
        };
    SparseCSR d:
        { rows = length d.pointers - 1, columns = d.extent };
    SparseCSC d:
        { rows = d.extent, columns = length d.pointers - 1 };
    esac;
Chris@208 43
// Number of columns in m.
width m =
   ({ columns } = size m;
    columns);

// Number of rows in m.
height m =
   ({ rows } = size m;
    rows);
Chris@208 46
// Count the non-zero cells of m. Dense storage is scanned cell by
// cell. Sparse storage reports the number of values it holds
// (makeSparse drops zero values when building it).
nonZeroValues m =
   (countNonZero v = length (filter do x: x != 0 done (vec.list v));
    countDense vectors = sum (map countNonZero vectors);
    case m of
    DenseRows rowVecs: countDense rowVecs;
    DenseCols colVecs: countDense colVecs;
    SparseCSR d: vec.length d.values;
    SparseCSC d: vec.length d.values;
    esac);
Chris@242 59
// Return the proportion of cells of m that are non-zero (0 to 1).
// A matrix with no cells (zero rows or zero columns) has density 0,
// rather than dividing by a zero cell count.
density m =
   ({ rows, columns } = size m;
    cells = rows * columns;
    if cells == 0 then
        0
    else
        (nonZeroValues m) / cells
    fi);
Chris@249 64
// Extract major slice n (a row of CSR data, a column of CSC data)
// from the sparse data d. The result holds the slice's stored values
// and their minor indices.
sparseSlice n d =
   (lo = d.pointers[n];
    hi = d.pointers[n+1];
    {
        indices = slice d.indices lo hi,
        values = vec.slice d.values lo hi,
    });
Chris@236 72
// Return an array of the major indices whose slices in the sparse
// data d hold at least one value (consecutive pointers differ).
nonEmptySlices d =
   (p = d.pointers;
    array (filter do k: p[k] != p[k+1] done [0 .. length p - 2]));
Chris@252 81
// Return the value stored at minor index m of major slice n in the
// sparse data d, or 0 if nothing is stored there. If the index occurs
// more than once, the last occurrence wins.
fromSlice n m d =
   ({ values, indices } = sparseSlice n d;
    var found = 0;
    for [0 .. length indices - 1] do k:
        if indices[k] == m then
            found := vec.at values k
        fi
    done;
    found);
Chris@236 91
// Expand major slice n of the sparse data d into a dense vector of
// length d.extent, with zeros wherever no value is stored.
filledSlice n d =
   ({ values, indices } = sparseSlice n d;
    dense = new double[d.extent];
    for [0 .. length indices - 1] do k:
        dense[indices[k]] := vec.at values k
    done;
    vec.vector dense);
Chris@236 99
// Return the value of the cell at (row, col) of m, for any storage.
at' m row col =
    case m of
    DenseRows rowVecs: vec.at rowVecs[row] col;
    DenseCols colVecs: vec.at colVecs[col] row;
    SparseCSR d: fromSlice row col d;
    SparseCSC d: fromSlice col row d;
    esac;
Chris@208 107
// Return column j of m as a vector. Dense column-major storage returns
// its stored vector directly, CSC storage expands the stored column,
// and any other storage builds the column one cell at a time.
getColumn j m =
    case m of
    SparseCSC d: filledSlice j d;
    DenseCols colVecs: colVecs[j];
    _: vec.fromList (map do row: at' m row j done [0 .. height m - 1]);
    esac;
Chris@208 114
// Return row i of m as a vector. Dense row-major storage returns its
// stored vector directly, CSR storage expands the stored row, and any
// other storage builds the row one cell at a time.
getRow i m =
    case m of
    SparseCSR d: filledSlice i d;
    DenseRows rowVecs: rowVecs[i];
    _: vec.fromList (map do col: at' m i col done [0 .. width m - 1]);
    esac;
Chris@208 121
// Return the rows of m as a list of vectors, first row first.
asRows m =
    map ((flip getRow) m) [0 .. height m - 1];

// Return the columns of m as a list of vectors, first column first.
asColumns m =
    map ((flip getColumn) m) [0 .. width m - 1];
Chris@257 127
// True when m is stored row by row (DenseRows or SparseCSR).
isRowMajor? m =
    case m of
    DenseRows _: true;
    SparseCSR _: true;
    DenseCols _: false;
    SparseCSC _: false;
    esac;

// True when m uses sparse storage (SparseCSR or SparseCSC).
isSparse? m =
    case m of
    SparseCSR _: true;
    SparseCSC _: true;
    DenseRows _: false;
    DenseCols _: false;
    esac;
Chris@94 143
// The storage-order tag of m: RowMajor () or ColumnMajor ().
typeOf m =
    if not isRowMajor? m then ColumnMajor () else RowMajor () fi;

// The storage-order tag opposite to that of m.
flippedTypeOf m =
    if not isRowMajor? m then RowMajor () else ColumnMajor () fi;
Chris@261 153
// Allocate an array of `columns` zero-filled vectors, each `rows`
// long. With fewer than one row the result is an empty array, so the
// column count is not retained.
newColumnMajorStorage { rows, columns } =
    if rows >= 1 then
        array (map \(vec.zeros rows) [1 .. columns])
    else
        array []
    fi;
Chris@94 158
// A zero-filled dense column-major matrix of the given size.
zeroMatrix { rows, columns } =
    DenseCols (newColumnMajorStorage { rows, columns });

// A zero-filled dense matrix of the given size that uses the same
// storage order as m. Row-major storage is built as `rows` vectors,
// each `columns` long.
zeroMatrixWithTypeOf m { rows, columns } =
    if not isRowMajor? m then
        DenseCols (newColumnMajorStorage { rows, columns })
    else
        DenseRows (newColumnMajorStorage { rows = columns, columns = rows })
    fi;

// A dense matrix with no rows and no columns.
zeroSizeMatrix () = zeroMatrix { rows = 0, columns = 0 };
Chris@214 170
// Build a dense column-major matrix of the given size whose cell at
// (row, col) is f row col. f is called column by column, top to
// bottom. Returns zeroSizeMatrix () if either dimension is below 1.
generate f { rows, columns } =
    if rows >= 1 and columns >= 1 then
        cells = array (map \(new double[rows]) [1 .. columns]);
        for [0 .. columns - 1] do c:
            colArr = cells[c];
            for [0 .. rows - 1] do r:
                colArr[r] := f r c
            done
        done;
        DenseCols (array (map vec.vector cells))
    else
        zeroSizeMatrix ()
    fi;
Chris@5 182
// Swap the i and j fields of every { i, j, v } record in a list,
// leaving v unchanged.
swapij entries =
    map do { i, j, v }: { i = j, j = i, v } done entries;
Chris@255 185
// List the stored cells of a sparse matrix m as { i, j, v } records,
// where i is the row, j the column and v the value, in storage order.
// Returns an empty list for dense matrices.
//!!! should use { row = , column = , value = } instead of i, j, v?
enumerateSparse m =
   (slices { values, indices, pointers } =
        concat
           (map do major:
                lo = pointers[major];
                hi = pointers[major+1];
                map2 do minor v: { i = major, j = minor, v } done
                    (slice indices lo hi)
                    (vec.list (vec.slice values lo hi))
            done [0 .. length pointers - 2]);
    case m of
    SparseCSR d: slices d;
    SparseCSC d: swapij (slices d);
    _: [];
    esac);
Chris@235 202
// List every cell of a dense matrix m, zeros included, as { i, j, v }
// records, where i is the row and j the column. Returns an empty list
// for sparse matrices.
enumerateDense m =
   (cellsOf vectors =
        concat
           (map do major:
                vv = vectors[major];
                map2 do minor v: { i = major, j = minor, v } done
                    [0 .. vec.length vv - 1]
                    (vec.list vv)
            done [0 .. length vectors - 1]);
    case m of
    DenseRows rowVecs: cellsOf rowVecs;
    DenseCols colVecs: swapij (cellsOf colVecs);
    _: [];
    esac);
Chris@255 217
// List the cells of m as { i, j, v } records: only the stored cells
// for sparse storage, every cell for dense storage.
enumerate m =
    if not isSparse? m then enumerateDense m else enumerateSparse m fi;
Chris@255 220
// Make a sparse matrix from entries whose i, j values are known to be
// within range.
//
// type selects the storage: RowMajor () builds CSR, ColumnMajor ()
// builds CSC. size is the { rows, columns } of the result. data is a
// list of { i, j, v } cells. Cells whose value is 0 are dropped and
// the rest are sorted by major and then minor index. Duplicate (i, j)
// cells are not detected.
makeSparse type size data =
   (isRow = case type of RowMajor (): true; ColumnMajor (): false esac;
    ordered =
        sortBy do a b:
            if a.maj == b.maj then a.min < b.min else a.maj < b.maj fi
        done
           (map
                if isRow then
                    do { i, j, v }: { maj = i, min = j, v } done;
                else
                    do { i, j, v }: { maj = j, min = i, v } done;
                fi
               (filter do d: d.v != 0 done data));
    tagger = if isRow then SparseCSR else SparseCSC fi;
    majorSize = if isRow then size.rows else size.columns fi;
    minorSize = if isRow then size.columns else size.rows fi;
    // pointers[k] is the offset of the first entry of major slice k,
    // and the final element is the total entry count, so the array
    // ends up majorSize + 1 long. Size the capacity from the major
    // dimension: size.rows is only right for row-major storage.
    pointers = array [0];
    setArrayCapacity pointers (majorSize + 1);
    // Walk the sorted major indices. n is the first slice whose end
    // pointer has not been pushed yet; i is the number of entries
    // that come before the head of majors.
    fillPointers n i majors =
        if n < majorSize then
            case majors of
            d::rest:
                (for [n..d-1] \(push pointers i);
                 fillPointers d (i+1) rest);
            _:
                for [n..majorSize-1] \(push pointers i);
            esac;
        fi;
    fillPointers 0 0 (map (.maj) ordered);
    tagger {
        values = vec.fromList (map (.v) ordered),
        indices = array (map (.min) ordered),
        pointers,
        extent = minorSize,
    });
Chris@238 258
// Make a sparse matrix from entries that may contain out-of-range
// cells, which are filtered out. This is the public API for
// makeSparse and is also used to discard out-of-range cells from
// resizedTo. A cell is kept only if i and j are integers with
// 0 <= i < size.rows and 0 <= j < size.columns.
newSparseMatrix type size data =
   (inRange limit x = x == int x and x >= 0 and x < limit;
    makeSparse type size
       (filter do { i, j, v }:
            inRange size.rows i and inRange size.columns j
        done data));
Chris@261 270
// Return m in sparse form with the same storage order. Sparse input is
// returned unchanged.
toSparse m =
    if not isSparse? m then
        makeSparse (typeOf m) (size m) (enumerateDense m)
    else
        m
    fi;

// Return m in dense form with the same storage order. Dense input is
// returned unchanged.
toDense m =
    if not (isSparse? m) then
        m
    elif isRowMajor? m then
        DenseRows (array (asRows m))
    else
        DenseCols (array (asColumns m))
    fi;
Chris@235 284
// A dense matrix of the given size with every cell set to n.
constMatrix n = generate do row col: n done;

// A dense matrix of the given size with each cell set from
// Math#random(), i.e. uniformly in [0, 1).
randomMatrix = generate do row col: Math#random() done;

// A dense matrix of the given size with 1 on the main diagonal
// (row == col) and 0 everywhere else. Non-square sizes are accepted;
// the diagonal then stops at the smaller dimension. (Previously this
// was constMatrix 1, which set every cell to 1.)
identityMatrix = generate do row col: if row == col then 1 else 0 fi done;
Chris@5 288
// Return the transpose of m by relabelling its storage: row storage
// becomes column storage and vice versa. No element is moved or
// copied.
transposed m =
    case m of
    DenseCols d: DenseRows d;
    DenseRows d: DenseCols d;
    SparseCSC d: SparseCSR d;
    SparseCSR d: SparseCSC d;
    esac;
Chris@100 296
// Return a matrix with the same contents as m but the opposite storage
// order. Sparse input stays sparse; dense input is rebuilt densely,
// one cell at a time.
flipped m =
    if isSparse? m then
        makeSparse (flippedTypeOf m) (size m) (enumerateSparse m)
    elif not isRowMajor? m then
        // generate yields column-major storage, so build the transpose
        // and relabel it as row-major
        transposed
           (generate do row col: at' m col row done
                { rows = width m, columns = height m })
    else
        generate do row col: at' m row col done (size m)
    fi;
Chris@100 309
// Return m in row-major storage, flipping it only when necessary.
toRowMajor m =
    if not isRowMajor? m then flipped m else m fi;

// Return m in column-major storage, flipping it only when necessary.
toColumnMajor m =
    if isRowMajor? m then flipped m else m fi;
Chris@161 315
// Compare two matrices already known to share density (both sparse or
// both dense) and storage order. Dense matrices are compared vector by
// vector with vecComparator. Sparse matrices must have equal extent,
// indices and pointers, and value vectors accepted by vecComparator.
// Any mismatch of storage variant gives false. comparator is accepted
// for symmetry with equal' but is not used here.
equal'' comparator vecComparator m1 m2 =
   (sameVectors a b = all id (map2 vecComparator a b);
    sameSparse a b =
        a.extent == b.extent and
        vecComparator a.values b.values and
        a.indices == b.indices and
        a.pointers == b.pointers;
    case m1 of
    DenseRows a:
        case m2 of DenseRows b: sameVectors a b; _: false; esac;
    DenseCols a:
        case m2 of DenseCols b: sameVectors a b; _: false; esac;
    SparseCSR a:
        case m2 of SparseCSR b: sameSparse a b; _: false; esac;
    SparseCSC a:
        case m2 of SparseCSC b: sameSparse a b; _: false; esac;
    esac);
Chris@238 334
// Compare m1 and m2 cell by cell. Matrices of different sizes are
// unequal. Otherwise m1 is flipped to match m2's storage order, and
// the dense one of a dense/sparse pair is made sparse. equal'' then
// compares the two matching representations.
equal' comparator vecComparator m1 m2 =
    if size m1 != size m2 then
        false
    elif isRowMajor? m1 != isRowMajor? m2 then
        equal' comparator vecComparator (flipped m1) m2
    elif isSparse? m1 == isSparse? m2 then
        equal'' comparator vecComparator m1 m2
    elif isSparse? m1 then
        equal' comparator vecComparator m1 (toSparse m2)
    else
        equal' comparator vecComparator (toSparse m1) m2
    fi;
Chris@97 349
// Compare matrices using the given comparator for individual cells.
// Matrices with different storage order but the same contents are
// equal, although comparing them is slow.
//
// When either matrix is sparse, both are compared in sparse form, and
// that comparison requires identical index and pointer arrays. Sparse
// matrices can therefore only be equal if they have the same set of
// non-zero cells, whatever comparator is used.
equalUnder comparator m1 m2 =
    equal' comparator (vec.equalUnder comparator) m1 m2;

// Exact cell-by-cell equality (same caveats as equalUnder).
equal m1 m2 =
    equal' (==) vec.equal m1 m2;
Chris@226 360
// Wrap a list of vectors as a dense matrix. Each vector is a row for
// RowMajor () or a column for ColumnMajor (). An empty list, or an
// empty first vector, gives zeroSizeMatrix (). The vectors are not
// copied: the matrix shares them with the caller.
newMatrix type data =
    if empty? data or vec.empty? (head data) then
        zeroSizeMatrix ()
    else
        case type of
        RowMajor (): DenseRows (array data);
        ColumnMajor (): DenseCols (array data);
        esac
    fi;

// A one-row matrix whose row is the vector data (shared, not copied).
newRowVector data =
    DenseRows (array [data]);

// A one-column matrix whose column is the vector data (shared, not
// copied).
newColumnVector data =
    DenseCols (array [data]);
Chris@8 373
// Combine m1 and m2 vector by vector with op and return a dense matrix
// in m1's storage order. op receives corresponding rows when m1 is
// row-major and corresponding columns otherwise. Presumably the
// callers have already checked that the sizes match.
denseLinearOp op m1 m2 =
   (pieces = if isRowMajor? m1 then asRows else asColumns fi;
    newMatrix (typeOf m1) (map2 op (pieces m1) (pieces m2)));
Chris@257 382
// Combine m1 and m2 cell by cell with op, collecting their enumerated
// cells in a hash keyed by row and then by column. A cell of m2 with
// no counterpart in m1 is combined as op 0 v. The result is sparse,
// with m1's size and storage order; makeSparse drops zero results.
sparseSumOrDifference op m1 m2 =
   (cells = [:];
    rowOf i =
       (if not (i in cells) then cells[i] := [:] fi;
        cells[i]);
    for (enumerate m1) do { i, j, v }:
        row = rowOf i;
        row[j] := v;
    done;
    for (enumerate m2) do { i, j, v }:
        row = rowOf i;
        row[j] := if j in row then op row[j] v else op 0 v fi;
    done;
    entries =
        concat
           (map do i:
                row = cells[i];
                map do j: { i, j, v = row[j] } done (keys row)
            done (keys cells));
    makeSparse (typeOf m1) (size m1) entries);
Chris@257 401
// Cell-wise sum of two matrices of the same size. Two sparse inputs
// give a sparse result; otherwise the result is dense, in m1's
// storage order. Fails if the sizes differ.
sum' m1 m2 =
    if (size m1) != (size m2) then
        failWith "Matrices are not the same size: \(size m1), \(size m2)"
    elif not (isSparse? m1 and isSparse? m2) then
        denseLinearOp bf.add m1 m2
    else
        sparseSumOrDifference (+) m1 m2
    fi;

// Cell-wise difference m1 - m2 of two matrices of the same size. Two
// sparse inputs give a sparse result; otherwise the result is dense,
// in m1's storage order. Fails if the sizes differ.
difference m1 m2 =
    if (size m1) != (size m2) then
        failWith "Matrices are not the same size: \(size m1), \(size m2)"
    elif not (isSparse? m1 and isSparse? m2) then
        denseLinearOp bf.subtract m1 m2
    else
        sparseSumOrDifference (-) m1 m2
    fi;
Chris@257 419
// Multiply every cell of m by factor, keeping m's density and storage
// order.
scaled factor m =
    if not (isSparse? m) then
        newMatrix (typeOf m)
           (map (bf.scaled factor)
                (if isRowMajor? m then asRows m else asColumns m fi))
    else
        makeSparse (typeOf m) (size m)
           (map do { i, j, v }: { i, j, v = factor * v } done (enumerate m))
    fi;

// Replace every cell of m with its absolute value, keeping m's density
// and storage order.
abs' m =
    if not (isSparse? m) then
        newMatrix (typeOf m)
           (map bf.abs (if isRowMajor? m then asRows m else asColumns m fi))
    else
        makeSparse (typeOf m) (size m)
           (map do { i, j, v }: { i, j, v = abs v } done (enumerate m))
    fi;
Chris@229 439
// Keep the cells of m whose values satisfy f and set every other cell
// to 0. Sparse input gives sparse output (zeroed cells are dropped);
// dense input gives dense output in the same storage order.
filter f m =
    if not (isSparse? m) then
        keepOrZero = vec.fromList . (map do x: if f x then x else 0 fi done) . vec.list;
        newMatrix (typeOf m)
           (map keepOrZero (if isRowMajor? m then asRows m else asColumns m fi))
    else
        makeSparse (typeOf m) (size m)
           (map do { i, j, v }: { i, j, v = if f v then v else 0 fi } done
                (enumerate m))
    fi;
Chris@258 453
// Product of a sparse matrix m1 and a matrix m2 (product calls this
// only when m2 is dense). size is the { rows, columns } of the result.
// The result is dense and column-major. For each column of m2, every
// stored value of m1 is multiplied by the matching element of that
// column and accumulated into the output cell for its row.
sparseProductLeft size m1 m2 =
   (d1 = case m1 of
        SparseCSR d: d;
        SparseCSC d: d;
        _: failWith "sparseProductLeft called for non-sparse m1";
        esac;
    swap = not isRowMajor? m1;
    vv = d1.values;
    ii = d1.indices;
    pp = d1.pointers;
    nnz = length ii;
    // Major index (row for CSR, column for CSC) of each stored value.
    // It depends only on m1, so compute it once here instead of once
    // per output column. The inner loop steps over empty slices.
    majors = array [];
    setArrayCapacity majors nnz;
    var k = 0;
    for [0..nnz - 1] do ix:
        ix == pp[k+1] loop (k := k + 1);
        push majors k;
    done;
    // The stored value at ix lies at row rowIdx[ix] and column
    // colIdx[ix] of m1. Choosing these arrays up front removes the
    // CSR/CSC branch from the inner loop.
    rowIdx = if swap then ii else majors fi;
    colIdx = if swap then majors else ii fi;
    data = array (map \(new double[size.rows]) [1..size.columns]);
    for [0..size.columns - 1] do j':
        c = getColumn j' m2;
        target = data[j'];
        for [0..nnz - 1] do ix:
            r = rowIdx[ix];
            target[r] := target[r] + (vec.at vv ix) * (vec.at c colIdx[ix]);
        done;
    done;
    DenseCols (array (map vec.vector (list data))));
Chris@249 479
// Product of a matrix m1 (product calls this only when m1 is dense)
// and a sparse matrix m2. size is the { rows, columns } of the result.
// The result is dense and row-major. Each row of m1 is combined with
// every stored cell of m2.
sparseProductRight size m1 m2 =
   (cells = enumerateSparse m2;
    out = array (map \(new double[size.columns]) [1..size.rows]);
    for [0..size.rows - 1] do r:
        rowVec = getRow r m1;
        acc = out[r];
        for cells do { i, j, v }:
            acc[j] := acc[j] + v * (vec.at rowVec i)
        done
    done;
    DenseRows (array (map vec.vector (list out))));
Chris@249 490
// Product of two sparse matrices, where size is the { rows, columns }
// of the result. A CSR m2 is first flipped to CSC. Each non-empty
// column of m2 is then loaded into a hash from row index to value,
// and every stored cell of m1 is matched against that hash. The result
// is sparse and column-major. Fails for any other m2.
sparseProduct size m1 m2 =
    case m2 of
    SparseCSR _:
        sparseProduct size m1 (flipped m2);
    SparseCSC d:
       (cells1 = enumerateSparse m1;
        columnEntries col =
           (s = sparseSlice col d;
            colHash = mapIntoHash (at s.indices) (vec.at s.values)
                [0 .. length s.indices - 1];
            sums = [:];
            for cells1 do { i, j, v }:
                if j in colHash then
                    prev = if i in sums then sums[i] else 0 fi;
                    sums[i] := (v * colHash[j]) + prev
                fi
            done;
            map do i: { i, j = col, v = sums[i] } done (keys sums));
        makeSparse (ColumnMajor ()) size
           (concat (map columnEntries (nonEmptySlices d))));
    _: failWith "sparseProduct called for non-sparse matrices";
    esac;
Chris@251 517
// Product of two dense matrices, where size is the { rows, columns }
// of the result. Each output cell is the sum of the element-wise
// product of a row of m1 and a column of m2. The result is dense and
// column-major.
denseProduct size m1 m2 =
   (data = array (map \(new double[size.rows]) [1..size.columns]);
    // Fetch every column of m2 once. getColumn on a row-major m2 builds
    // the column cell by cell, and the loop below would otherwise
    // repeat that for every row of m1.
    cols = array (asColumns m2);
    for [0..size.rows - 1] do i:
        row = getRow i m1;
        for [0..size.columns - 1] do j:
            data[j][i] := bf.sum (bf.multiply row cols[j]);
        done;
    done;
    DenseCols (array (map vec.vector (list data))));
Chris@249 527
// Matrix product m1 * m2. Fails unless m1 has as many columns as m2
// has rows. The implementation is chosen by density: sparse * sparse
// gives a sparse column-major result, and every other combination
// gives a dense result.
product m1 m2 =
    if (size m1).columns != (size m2).rows then
        failWith "Matrix dimensions incompatible: \(size m1), \(size m2) (\((size m1).columns) != \((size m2).rows))"
    else
        outSize = { rows = (size m1).rows, columns = (size m2).columns };
        if isSparse? m1 and isSparse? m2 then
            sparseProduct outSize m1 m2
        elif isSparse? m1 then
            sparseProductLeft outSize m1 m2
        elif isSparse? m2 then
            sparseProductRight outSize m1 m2
        else
            denseProduct outSize m1 m2
        fi
    fi;
Chris@98 545
// Concatenate the matrices mm across their storage order. Slice i of
// the result (fetched with getter) is the concatenation of slice i of
// every matrix. counter gives the number of slices from the size of
// the first matrix, and tagger wraps the resulting array of vectors.
concatAgainstGrain tagger getter counter mm =
   (n = counter (size (head mm));
    joined i = vec.concat (map (getter i) mm);
    tagger (array (map joined [0..n-1])));

// Concatenate the matrices mm along their storage order. The slices of
// each matrix (counter gives how many, from its size) are appended in
// turn, and tagger wraps the resulting array of vectors.
concatWithGrain tagger getter counter mm =
   (slicesOf m = map ((flip getter) m) [0 .. counter (size m) - 1];
    tagger (array (concat (map slicesOf mm))));
Chris@177 560
// Concatenate the sparse matrices mm in the given direction into one
// sparse matrix with first's storage order. Along the direction of
// concatenation the result's extent is the sum of the parts' extents;
// across it, the extent is taken from first. Each part's cells are
// offset by the combined extent of the parts before it.
sparseConcat direction first mm =
   (vertical = direction == Vertical ();
    total f = sum (map f mm);
    rows = if vertical then total height else height first fi;
    columns = if vertical then width first else total width fi;
    shifted ioff joff parts =
        case parts of
        m::rest:
            (map do { i, j, v }: { i = i + ioff, j = j + joff, v }
             done (enumerate m))
            ++ (if vertical then shifted (ioff + height m) joff rest
                else shifted ioff (joff + width m) rest fi);
        _: [];
        esac;
    makeSparse (typeOf first) { rows, columns } (shifted 0 0 mm));
Chris@259 579
// Fail unless every matrix in mm matches first in the dimension that
// must agree for concatenation: rows for Horizontal (), columns for
// Vertical ().
checkDimensionsFor direction first mm =
   (counter = if direction == Horizontal () then (.rows) else (.columns) fi;
    n = counter (size first);
    if not (all do m: counter (size m) == n done mm) then
        failWith "Matrix dimensions incompatible for concat (found \(map do m: counter (size m) done mm) not all of which are \(n))";
    fi);
Chris@178 586
// Concatenate the list of matrices mm in the given direction
// (Horizontal () side by side, Vertical () stacked). The storage order
// is taken from the first matrix in the list. An empty list gives
// zeroSizeMatrix (), and a single matrix is returned unchanged. The
// result is sparse if every matrix is sparse and dense otherwise.
// Fails if the dimensions that must agree do not.
concat direction mm =
    case length mm of
    0: zeroSizeMatrix ();
    1: head mm;
    _:
        first = head mm;
        checkDimensionsFor direction first mm;
        if all isSparse? mm then
            sparseConcat direction first mm
        else
            rowMajor = isRowMajor? first;
            // Horizontal joins extend each row (against the grain of
            // row-major storage); vertical joins append rows (with it).
            case direction of
            Horizontal ():
                if not rowMajor then concatWithGrain DenseCols getColumn (.columns) mm
                else concatAgainstGrain DenseRows getRow (.rows) mm
                fi;
            Vertical ():
                if not rowMajor then concatAgainstGrain DenseCols getColumn (.columns) mm
                else concatWithGrain DenseRows getRow (.rows) mm
                fi;
            esac
        fi;
    esac;
Chris@177 614
// Return rows start to end - 1 of m as a dense matrix with the same
// storage order as m (sparse input gives dense output). The argument
// order (m, start, end) mirrors the standard library's slice.
rowSlice m start end =
    if not isRowMajor? m then
        DenseCols (array (map do col: vec.slice col start end done (asColumns m)))
    else
        DenseRows (array (map do r: getRow r m done [start .. end - 1]))
    fi;

// Return columns start to end - 1 of m as a dense matrix with the same
// storage order as m (sparse input gives dense output). The argument
// order (m, start, end) mirrors the standard library's slice.
columnSlice m start end =
    if isRowMajor? m then
        DenseRows (array (map do row: vec.slice row start end done (asRows m)))
    else
        DenseCols (array (map do c: getColumn c m done [start .. end - 1]))
    fi;
Chris@187 630
// Return a matrix of size newsize holding the cells of m that fall
// inside it, with any added cells zero. Storage order follows m.
// Sparse input stays sparse, and out-of-range cells are discarded.
// Dense input is adjusted one step at a time: it is cut with
// rowSlice/columnSlice or padded with a zero matrix via concat, then
// the function recurses until the size matches.
resizedTo newsize m =
    if newsize == (size m) then
        m
    elif isSparse? m then
        // don't call makeSparse directly: want to discard
        // out-of-range cells
        newSparseMatrix (typeOf m) newsize (enumerateSparse m)
    elif (height m) == 0 or (width m) == 0 then
        zeroMatrixWithTypeOf m newsize
    else
        extraRows = newsize.rows - height m;
        extraCols = newsize.columns - width m;
        rowMajor = isRowMajor? m;
        // Shrink before growing. When both dimensions shrink, cut
        // along the storage order first.
        step =
            if extraRows < 0 and (rowMajor or extraCols >= 0) then
                rowSlice m 0 newsize.rows
            elif extraCols < 0 then
                columnSlice m 0 newsize.columns
            elif extraRows > 0 then
                concat (Vertical ())
                    [m, zeroMatrixWithTypeOf m ((size m) with { rows = extraRows })]
            else
                concat (Horizontal ())
                    [m, zeroMatrixWithTypeOf m ((size m) with { columns = extraCols })]
            fi;
        resizedTo newsize step
    fi;
Chris@201 663
// Public interface of the module. Internal definitions whose names end
// in ' are exported without the suffix (at, sum, abs). The signature
// previously declared a `thresholded` field that this record does not
// provide (it was superseded by filter); that stale entry is removed.
{
    size,
    width,
    height,
    density,
    nonZeroValues,
    at = at',
    getColumn,
    getRow,
    isRowMajor?,
    isSparse?,
    generate,
    constMatrix,
    randomMatrix,
    zeroMatrix,
    identityMatrix,
    zeroSizeMatrix,
    equal,
    equalUnder,
    transposed,
    flipped,
    toRowMajor,
    toColumnMajor,
    toSparse,
    toDense,
    scaled,
    resizedTo,
    asRows,
    asColumns,
    sum = sum',
    difference,
    abs = abs',
    filter,
    product,
    concat,
    rowSlice,
    columnSlice,
    newMatrix,
    newRowVector,
    newColumnVector,
    newSparseMatrix,
    enumerate
}
as
{
    //!!! check whether these are right to be .selector rather than just selector

    size is matrix -> { .rows is number, .columns is number },
    width is matrix -> number,
    height is matrix -> number,
    density is matrix -> number,
    nonZeroValues is matrix -> number,
    at is matrix -> number -> number -> number,
    getColumn is number -> matrix -> vector,
    getRow is number -> matrix -> vector,
    isRowMajor? is matrix -> boolean,
    isSparse? is matrix -> boolean,
    generate is (number -> number -> number) -> { .rows is number, .columns is number } -> matrix,
    constMatrix is number -> { .rows is number, .columns is number } -> matrix,
    randomMatrix is { .rows is number, .columns is number } -> matrix,
    zeroMatrix is { .rows is number, .columns is number } -> matrix,
    identityMatrix is { .rows is number, .columns is number } -> matrix,
    zeroSizeMatrix is () -> matrix,
    equal is matrix -> matrix -> boolean,
    equalUnder is (number -> number -> boolean) -> matrix -> matrix -> boolean,
    transposed is matrix -> matrix,
    flipped is matrix -> matrix,
    toRowMajor is matrix -> matrix,
    toColumnMajor is matrix -> matrix,
    toSparse is matrix -> matrix,
    toDense is matrix -> matrix,
    scaled is number -> matrix -> matrix,
    resizedTo is { .rows is number, .columns is number } -> matrix -> matrix,
    asRows is matrix -> list<vector>,
    asColumns is matrix -> list<vector>,
    sum is matrix -> matrix -> matrix,
    difference is matrix -> matrix -> matrix,
    abs is matrix -> matrix,
    filter is (number -> boolean) -> matrix -> matrix,
    product is matrix -> matrix -> matrix,
    concat is (Horizontal () | Vertical ()) -> list<matrix> -> matrix,
    rowSlice is matrix -> number -> number -> matrix,
    columnSlice is matrix -> number -> number -> matrix,
    newMatrix is (ColumnMajor () | RowMajor ()) -> list<vector> -> matrix,
    newRowVector is vector -> matrix,
    newColumnVector is vector -> matrix,
    newSparseMatrix is (ColumnMajor () | RowMajor ()) -> { .rows is number, .columns is number } -> list<{ .i is number, .j is number, .v is number }> -> matrix,
    enumerate is matrix -> list<{ .i is number, .j is number, .v is number }>
}
Chris@5 754