annotate yetilab/matrix/matrix.yeti @ 272:2ebda6646c40

Quicker, though uglier, sparse products
author Chris Cannam
date Thu, 23 May 2013 17:15:27 +0100
parents c206de7c3018
children
rev   line source
Chris@5 1
Chris@94 2 module yetilab.matrix.matrix;
Chris@94 3
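// A matrix is either dense, stored as an array of vectors, or sparse,
// stored in compressed sparse row (CSR) or compressed sparse column
// (CSC) form.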
Chris@94 5
// A matrix can be stored in either column-major (the default) or
// row-major format. Storage order is an efficiency concern only:
// every API function operating on matrix objects will return the same
// result regardless of storage order. (The transposed function just
// switches the row/column order without moving the elements.)
Chris@18 11
Chris@222 12 vec = load yetilab.vector.vector;
Chris@222 13 bf = load yetilab.vector.blockfuncs;
Chris@9 14
Chris@222 15 load yetilab.vector.vectortype;
Chris@195 16 load yetilab.matrix.matrixtype;
Chris@195 17
// Return the dimensions of m as { rows, columns }.
// Dense storage: the major dimension is the number of stored vectors
// and the minor dimension is the length of the first vector (0 when
// there are no vectors). Sparse storage: the major dimension is one
// less than the number of slice pointers and the minor dimension is
// the stored extent.
size m =
   (minorOf vv = if length vv > 0 then vec.length vv[0] else 0 fi;
    case m of
    DenseRows r: { rows = length r, columns = minorOf r };
    DenseCols c: { rows = minorOf c, columns = length c };
    SparseCSR d: { rows = length d.pointers - 1, columns = d.extent };
    SparseCSC d: { rows = d.extent, columns = length d.pointers - 1 };
    esac);
Chris@208 43
// Number of columns in m.
width m =
   ({ columns } = size m;
    columns);

// Number of rows in m.
height m =
   ({ rows } = size m;
    rows);
Chris@208 46
// Count the cells of m that are not zero.
// Dense storage is scanned value by value. For sparse storage the
// count is simply the number of stored values (makeSparse drops zero
// values when building sparse matrices).
nonZeroValues m =
   (countIn vv =
        sum (map do v: length (filter do x: x != 0 done (vec.list v)) done vv);
    case m of
    DenseRows d: countIn d;
    DenseCols d: countIn d;
    SparseCSR d: vec.length d.values;
    SparseCSC d: vec.length d.values;
    esac);
Chris@242 59
// Fraction of the cells of m that are non-zero.
// A matrix with no cells (zero rows or zero columns) has density 0,
// rather than dividing by zero.
density m =
   ({ rows, columns } = size m;
    cells = rows * columns;
    if cells == 0 then
        0
    else
        (nonZeroValues m) / cells
    fi);
Chris@249 64
// Extract major slice n (row n for CSR, column n for CSC) of the
// sparse storage record d, as { values, indices }. The indices are
// the minor indices of the stored values.
sparseSlice n d =
   (lo = d.pointers[n];
    hi = d.pointers[n+1];
    { values = vec.slice d.values lo hi, indices = slice d.indices lo hi });
Chris@236 72
// Return an array of the major indices of sparse storage d whose
// slices hold at least one stored value, in ascending order.
nonEmptySlices d =
    array
       (filter do k: d.pointers[k] != d.pointers[k+1] done
            [0 .. length d.pointers - 2]);
Chris@252 81
// Look up the value stored at minor index m within major slice n of
// sparse storage d, giving 0 when nothing is stored there.
fromSlice n m d =
   (s = sparseSlice n d;
    var found = 0;
    for [0 .. length s.indices - 1] do k:
        if s.indices[k] == m then
            found := vec.at s.values k
        fi
    done;
    found);

// Expand major slice n of sparse storage d into a dense vector of
// length d.extent, with zeros wherever nothing is stored.
filledSlice n d =
   (s = sparseSlice n d;
    out = new double[d.extent];
    for [0 .. length s.indices - 1] do k:
        out[s.indices[k]] := vec.at s.values k
    done;
    vec.vector out);
Chris@236 99
// Value of the cell at (row, col) of m, whatever its storage format.
at' m row col =
    case m of
    DenseRows stored: vec.at stored[row] col;
    DenseCols stored: vec.at stored[col] row;
    SparseCSR d: fromSlice row col d;
    SparseCSC d: fromSlice col row d;
    esac;
Chris@208 107
// Column j of m as a vector. Dense column-major storage returns its
// stored vector itself (not a copy); CSC storage expands the stored
// slice; any other storage is read one cell at a time.
getColumn j m =
    case m of
    DenseCols stored: stored[j];
    SparseCSC d: filledSlice j d;
    _: vec.fromList (map do r: at' m r j done [0 .. height m - 1]);
    esac;

// Row i of m as a vector. Dense row-major storage returns its stored
// vector itself (not a copy); CSR storage expands the stored slice;
// any other storage is read one cell at a time.
getRow i m =
    case m of
    DenseRows stored: stored[i];
    SparseCSR d: filledSlice i d;
    _: vec.fromList (map (at' m i) [0 .. width m - 1]);
    esac;

// All rows of m, top to bottom, as a list of vectors.
asRows m =
    map ((flip getRow) m) [0 .. height m - 1];

// All columns of m, left to right, as a list of vectors.
asColumns m =
    map ((flip getColumn) m) [0 .. width m - 1];
Chris@257 127
// True when m is stored row by row (DenseRows or SparseCSR).
isRowMajor? m =
    case m of
    DenseRows _: true;
    SparseCSR _: true;
    DenseCols _: false;
    SparseCSC _: false;
    esac;

// True when m uses sparse storage (SparseCSR or SparseCSC).
isSparse? m =
    case m of
    SparseCSR _: true;
    SparseCSC _: true;
    DenseRows _: false;
    DenseCols _: false;
    esac;

// The storage order of m as RowMajor () or ColumnMajor ().
typeOf m =
    if not isRowMajor? m then ColumnMajor ()
    else RowMajor ()
    fi;

// The storage order opposite to that of m.
flippedTypeOf m =
    case typeOf m of
    RowMajor (): ColumnMajor ();
    ColumnMajor (): RowMajor ();
    esac;
Chris@261 153
// Allocate zero-filled column-major storage: an array of `columns`
// zero vectors, each `rows` long. No vectors are allocated when rows
// is less than 1.
newColumnMajorStorage { rows, columns } =
    if rows >= 1 then array (map \(vec.zeros rows) [1 .. columns])
    else array []
    fi;

// A dense column-major matrix of the given size with every cell zero.
zeroMatrix dims =
    DenseCols (newColumnMajorStorage dims);

// A dense matrix of the given size with every cell zero, stored in the
// same order (row- or column-major) as m.
zeroMatrixWithTypeOf m { rows, columns } =
    if not isRowMajor? m then
        DenseCols (newColumnMajorStorage { rows, columns })
    else
        // one stored vector per row, each `columns` long
        DenseRows (newColumnMajorStorage { rows = columns, columns = rows })
    fi;

// An empty (0 x 0) dense matrix.
zeroSizeMatrix () = zeroMatrix { rows = 0, columns = 0 };
Chris@214 170
// Build a dense column-major matrix of the given size whose cell at
// (row, col) is f row col. f is called column by column, and top to
// bottom within each column. A size with fewer than one row or column
// gives zeroSizeMatrix ().
generate f { rows, columns } =
    if rows >= 1 and columns >= 1 then
        cols = array (map \(new double[rows]) [1 .. columns]);
        for [0 .. columns - 1] do c:
            col = cols[c];
            for [0 .. rows - 1] do r:
                col[r] := f r c
            done
        done;
        DenseCols (array (map vec.vector cols))
    else
        zeroSizeMatrix ()
    fi;
Chris@5 182
// Exchange the i and j fields of every { i, j, v } entry in a list.
swapij entries =
    map do e: { i = e.j, j = e.i, v = e.v } done entries;
Chris@255 185
//!!! should use { row = , column = , value = } instead of i, j, v?
// List the stored entries of a sparse matrix as { i, j, v } records,
// with i the row index and j the column index. Entries appear slice by
// slice, in stored order within each slice. Dense matrices give [].
enumerateSparse m =
   (entriesOf d =
        concat
           (map do major:
                s = sparseSlice major d;
                map2 do minor value: { i = major, j = minor, v = value } done
                    s.indices
                    (vec.list s.values)
            done [0 .. length d.pointers - 2]);
    case m of
    SparseCSR d: entriesOf d;
    SparseCSC d: swapij (entriesOf d);
    _: [];
    esac);
Chris@235 202
// List every cell of a dense matrix (zeros included) as { i, j, v }
// records, with i the row index and j the column index. Sparse
// matrices give [].
enumerateDense m =
   (entriesOf vv =
        concat
           (map do major:
                stored = vv[major];
                map2 do minor value: { i = major, j = minor, v = value } done
                    [0 .. vec.length stored - 1]
                    (vec.list stored)
            done [0 .. length vv - 1]);
    case m of
    DenseRows r: entriesOf r;
    DenseCols c: swapij (entriesOf c);
    _: [];
    esac);

// List the cells of m as { i, j, v } records: every stored entry of a
// sparse matrix, or every cell of a dense one.
enumerate m =
    if not isSparse? m then enumerateDense m
    else enumerateSparse m
    fi;
Chris@255 220
// Make a sparse matrix from entries whose i, j values are known to be
// within range.
// type is RowMajor () for CSR storage or ColumnMajor () for CSC;
// size is { rows, columns }; data is a list of { i, j, v } records.
// Entries with a zero value are dropped. The rest are sorted by major
// index, then by minor index.
// NOTE(review): nothing here merges duplicate (i, j) entries; each
// would be stored separately. Confirm that callers never pass them.
makeSparse type size data =
   (isRow = case type of RowMajor (): true; ColumnMajor (): false esac;
    ordered =
        sortBy do a b:
            if a.maj == b.maj then a.min < b.min else a.maj < b.maj fi
        done
           (map
                if isRow then
                    do { i, j, v }: { maj = i, min = j, v } done;
                else
                    do { i, j, v }: { maj = j, min = i, v } done;
                fi
               (filter do d: d.v != 0 done data));
    tagger = if isRow then SparseCSR else SparseCSC fi;
    majorSize = if isRow then size.rows else size.columns fi;
    minorSize = if isRow then size.columns else size.rows fi;
    // pointers[k] is the offset into values/indices of the first entry
    // of major slice k. The final element is the total entry count, so
    // once filled the array holds majorSize + 1 elements. (This is
    // majorSize, not size.rows, which differs for column-major storage.)
    pointers = array [0];
    setArrayCapacity pointers (majorSize + 1);
    // n is the lowest major slice whose end offset has not been pushed
    // yet; i is the number of sorted entries consumed so far; data is
    // the remaining list of major indices.
    fillPointers n i data =
        if n < majorSize then
            case data of
            d::rest:
                (for [n..d-1] \(push pointers i);
                 fillPointers d (i+1) rest);
            _:
                for [n..majorSize-1] \(push pointers i);
            esac;
        fi;
    fillPointers 0 0 (map (.maj) ordered);
    tagger {
        values = vec.fromList (map (.v) ordered),
        indices = array (map (.min) ordered),
        pointers,
        extent = minorSize,
    });
Chris@238 258
// Make a sparse matrix from { i, j, v } entries that may contain
// out-of-range cells, which are filtered out first. An entry is kept
// only if its row index i and column index j are non-negative whole
// numbers below size.rows and size.columns respectively. This is the
// public API for makeSparse, and resizedTo also uses it to discard
// out-of-range cells.
newSparseMatrix type size data =
   (inRange limit x = x == int x and x >= 0 and x < limit;
    makeSparse type size
       (filter
            do { i, j, v }:
                inRange size.rows i and inRange size.columns j
            done data));
Chris@261 270
// Convert m to sparse storage in the same storage order. A sparse m
// is returned as it is.
toSparse m =
    if not isSparse? m then
        makeSparse (typeOf m) (size m) (enumerateDense m)
    else
        m
    fi;

// Convert m to dense storage in the same storage order. A dense m is
// returned as it is.
toDense m =
    if not isSparse? m then m
    elif not isRowMajor? m then DenseCols (array (asColumns m))
    else DenseRows (array (asRows m))
    fi;
Chris@235 284
// A matrix of the given size with every cell set to n.
constMatrix n = generate do row col: n done;

// A matrix of the given size with every cell set to a pseudo-random
// value from java.lang.Math.random() (in the range [0, 1)).
randomMatrix = generate do row col: Math#random() done;

// An identity matrix of the given size: ones on the leading diagonal
// (row == col) and zeros everywhere else. Non-square sizes are
// allowed; the diagonal then stops at the shorter dimension.
identityMatrix = generate do row col: if row == col then 1 else 0 fi done;
Chris@5 288
// Transpose m by relabelling its storage order; no elements are moved.
transposed m =
    case m of
    DenseCols d: DenseRows d;
    DenseRows d: DenseCols d;
    SparseCSC d: SparseCSR d;
    SparseCSR d: SparseCSC d;
    esac;

// A matrix with the same contents as m but the opposite storage order
// (row-major <-> column-major). Sparse input gives sparse output and
// dense input gives dense output.
flipped m =
    if isSparse? m then
        makeSparse (flippedTypeOf m) (size m) (enumerateSparse m)
    elif not isRowMajor? m then
        // generate always builds column-major storage, so build the
        // transpose that way and relabel it as row-major
        transposed
           (generate do row col: at' m col row done
                { rows = width m, columns = height m })
    else
        generate do row col: at' m row col done (size m)
    fi;

// m in row-major storage, flipping it only if necessary.
toRowMajor m =
    if not isRowMajor? m then flipped m else m fi;

// m in column-major storage, flipping it only if necessary.
toColumnMajor m =
    if isRowMajor? m then flipped m else m fi;
Chris@161 315
// Compare two matrices already known to have the same representation
// (both sparse or both dense) and the same storage order.
// Dense matrices are compared vector by vector with vecComparator.
// Sparse matrices must have the same extent, indices and pointers,
// and their value vectors are compared with vecComparator.
// Differing representations compare unequal. The comparator argument
// is not used here.
equal'' comparator vecComparator m1 m2 =
    // Prerequisite: m1 and m2 have same sparse-p and storage order
   (sameVectors a b = all id (map2 vecComparator a b);
    sameSparse a b =
        a.extent == b.extent and
        vecComparator a.values b.values and
        a.indices == b.indices and
        a.pointers == b.pointers;
    case m1 of
    DenseRows a:
        case m2 of DenseRows b: sameVectors a b; _: false; esac;
    DenseCols a:
        case m2 of DenseCols b: sameVectors a b; _: false; esac;
    SparseCSR a:
        case m2 of SparseCSR b: sameSparse a b; _: false; esac;
    SparseCSC a:
        case m2 of SparseCSC b: sameSparse a b; _: false; esac;
    esac);
Chris@238 334
// Compare m1 and m2 for equality of size and contents. If their storage
// orders differ, m1 is flipped. If one is sparse and the other dense,
// the dense one is converted to sparse. Once both match, equal'' does
// the comparison.
equal' comparator vecComparator m1 m2 =
    if size m1 != size m2 then
        false
    elif isRowMajor? m1 != isRowMajor? m2 then
        equal' comparator vecComparator (flipped m1) m2
    elif isSparse? m1 == isSparse? m2 then
        equal'' comparator vecComparator m1 m2
    elif isSparse? m1 then
        equal' comparator vecComparator m1 (toSparse m2)
    else
        equal' comparator vecComparator (toSparse m1) m2
    fi;
Chris@97 349
// Compare matrices using the given comparator for individual cells.
// Matrices with different storage orders but the same contents are
// equal, although comparing them is slow.
// Sparse matrices can only be equal if they store the same set of
// non-zero cells (identical indices and pointers), whatever comparator
// is used.
equalUnder comparator m1 m2 =
    equal' comparator (vec.equalUnder comparator) m1 m2;

// Compare matrices for exact equality of size and contents.
equal m1 m2 =
    equal' (==) vec.equal m1 m2;
Chris@226 360
// Make a dense matrix from a list of vectors: rows for RowMajor (),
// columns for ColumnMajor (). An empty list, or one whose first vector
// is empty, gives zeroSizeMatrix ().
newMatrix type data = //!!! NB does not copy data
   (wrap = case type of ColumnMajor (): DenseCols; RowMajor (): DenseRows esac;
    if (not (empty? data)) and (not (vec.empty? (head data))) then
        wrap (array data)
    else
        zeroSizeMatrix ()
    fi);

// A 1-row matrix holding the vector data.
newRowVector data = //!!! NB does not copy data
    DenseRows
       (array [data]);

// A 1-column matrix holding the vector data.
newColumnVector data = //!!! NB does not copy data
    DenseCols
       (array [data]);
Chris@8 373
// Combine m1 and m2 vector by vector with the element-wise vector
// operation op. Both are read as rows if m1 is row-major, otherwise as
// columns. The result is a dense matrix with m1's storage order.
denseLinearOp op m1 m2 =
   (slices = if isRowMajor? m1 then asRows else asColumns fi;
    newMatrix (typeOf m1) (map2 op (slices m1) (slices m2)));
Chris@257 382
// Combine two sparse matrices cell by cell.
// A cell stored in m2 becomes op (m1 value) (m2 value), with 0 for the
// m1 value when m1 stores nothing there. A cell stored only in m1 keeps
// its m1 value. The result is sparse, with m1's storage order and size.
sparseSumOrDifference op m1 m2 =
   (byRow = [:];
    rowOf i =
       (if not (i in byRow) then byRow[i] := [:] fi;
        byRow[i]);
    for (enumerate m1) do { i, j, v }:
        r = rowOf i;
        r[j] := v
    done;
    for (enumerate m2) do { i, j, v }:
        r = rowOf i;
        r[j] := op (if j in r then r[j] else 0 fi) v
    done;
    entries =
        concat
           (map do i:
                r = byRow[i];
                map do j: { i, j, v = r[j] } done (keys r)
            done (keys byRow));
    makeSparse (typeOf m1) (size m1) entries);
Chris@257 401
// Shared implementation of sum' and difference. Fail unless m1 and m2
// have the same size. If both are sparse, combine them with
// sparseSumOrDifference using sparseOp; otherwise combine their
// vectors with the element-wise operation denseOp.
linearCombination sparseOp denseOp m1 m2 =
    if size m1 != size m2 then
        failWith "Matrices are not the same size: \(size m1), \(size m2)"
    elif isSparse? m1 and isSparse? m2 then
        sparseSumOrDifference sparseOp m1 m2
    else
        denseLinearOp denseOp m1 m2
    fi;

// Element-wise sum m1 + m2 of two matrices of the same size.
sum' m1 m2 = linearCombination (+) bf.add m1 m2;

// Element-wise difference m1 - m2 of two matrices of the same size.
difference m1 m2 = linearCombination (-) bf.subtract m1 m2;
Chris@257 419
// Shared implementation of scaled, abs' and filter. Rebuild m with the
// same storage order and size. For a sparse m, cellFn is applied to
// each enumerated value. For a dense m, vecFn is applied to each row
// (row-major) or column (column-major) vector.
cellwise cellFn vecFn m =
    if isSparse? m then
        makeSparse (typeOf m) (size m)
            (map do { i, j, v }: { i, j, v = cellFn v } done (enumerate m))
    else
        newMatrix (typeOf m)
            (map vecFn (if isRowMajor? m then asRows m else asColumns m fi))
    fi;

// m with every cell multiplied by factor.
scaled factor m =
    cellwise do v: factor * v done (bf.scaled factor) m;

// m with every cell replaced by its absolute value.
abs' m =
    cellwise abs bf.abs m;

// m with every cell for which f is false replaced by 0.
filter f m =
   (keep x = if f x then x else 0 fi;
    cellwise keep (vec.fromList . (map keep) . vec.list) m);
Chris@258 453
// Multiply sparse m1 by dense m2, giving a dense column-major matrix of
// the given size. For each column of m2, every stored entry of m1 is
// visited once and its contribution is added to the result column.
sparseProductLeft size m1 m2 =
   ({ values, indices, pointers } = case m1 of
        SparseCSR d: d;
        SparseCSC d: d;
        _: failWith "sparseProductLeft called for non-sparse m1";
        esac;
    rowMajor = isRowMajor? m1;
    nnz = length indices;
    out = array (map \(new double[size.rows]) [1..size.columns]);
    for [0 .. size.columns - 1] do col:
        src = getColumn col m2;
        dst = out[col];
        var major = 0;
        for [0 .. nnz - 1] do k:
            // advance to the slice that holds entry k, skipping any
            // empty slices
            k == pointers[major+1] loop (major := major + 1);
            minor = indices[k];
            row = if rowMajor then major else minor fi;
            inner = if rowMajor then minor else major fi;
            dst[row] := dst[row] + (vec.at values k) * (vec.at src inner);
        done;
    done;
    DenseCols (array (map vec.vector (list out))));
Chris@249 473
// Multiply dense m1 by sparse m2, giving a dense row-major matrix of
// the given size. For each row of m1, every stored entry (i, j) of m2
// adds m1[row][i] * m2[i][j] to cell j of the result row.
// Fails if m2 is not sparse.
sparseProductRight size m1 m2 =
   ({ values, indices, pointers } = case m2 of
        SparseCSR d: d;
        SparseCSC d: d;
        _: failWith "sparseProductRight called for non-sparse m2";
        esac;
    rows = isRowMajor? m2;
    data = array (map \(new double[size.columns]) [1..size.rows]);
    for [0..size.rows - 1] do i':
        r = getRow i' m1;
        var p = 0;
        for [0..length indices - 1] do ix:
            // advance p to the major slice that holds entry ix,
            // skipping any empty slices
            ix == pointers[p+1] loop (p := p + 1);
            i = if rows then p else indices[ix] fi;
            j = if rows then indices[ix] else p fi;
            data[i'][j] := data[i'][j] + (vec.at values ix) * (vec.at r i);
        done;
    done;
    DenseRows (array (map vec.vector (list data))));
Chris@249 493
// Multiply two sparse matrices, giving a sparse column-major (CSC)
// matrix of the given size. A CSR m2 is first flipped to CSC.
// For each non-empty column of m2, that column's entries are put in a
// hash from row index to value. Every stored entry of m1 whose column
// index appears in the hash then adds to the matching output row.
sparseProduct size m1 m2 =
    case m2 of
    SparseCSR _:
        sparseProduct size m1 (flipped m2);
    SparseCSC right:
       ({ values, indices, pointers } = case m1 of
            SparseCSR d1: d1;
            SparseCSC d1: d1;
            _: failWith "sparseProduct called for non-sparse matrices";
            esac;
        rowMajor = isRowMajor? m1;
        nnz = length indices;
        // majors[k] is the major (slice) index of m1's stored entry k
        majors = new int[nnz];
        var current = 0;
        for [0 .. nnz - 1] do k:
            k == pointers[current+1] loop (current := current + 1);
            majors[k] := current;
        done;
        columnEntries col =
           (s = sparseSlice col right;
            colValues = mapIntoHash
               (at s.indices) (vec.at s.values)
               [0 .. length s.indices - 1];
            acc = [:];
            for [0 .. nnz - 1] do k:
                row = if rowMajor then majors[k] else indices[k] fi;
                inner = if rowMajor then indices[k] else majors[k] fi;
                if inner in colValues then
                    prod = (vec.at values k) * colValues[inner];
                    acc[row] := prod + (if row in acc then acc[row] else 0 fi);
                fi;
            done;
            map do row: { i = row, j = col, v = acc[row] } done (keys acc));
        makeSparse (ColumnMajor ()) size
            (concat (map columnEntries (nonEmptySlices right))));
    _: failWith "sparseProduct called for non-sparse matrices";
    esac;
Chris@251 533
// Multiply m1 by m2 directly, giving a dense column-major matrix of the
// given size. Cell (i, j) is the dot product of row i of m1 and
// column j of m2.
denseProduct size m1 m2 =
   (data = array (map \(new double[size.rows]) [1..size.columns]);
    // Fetch every column of m2 once, up front. getColumn may have to
    // assemble a column cell by cell (e.g. for row-major m2), and the
    // loop below reads each column once per row of m1.
    m2cols = array (asColumns m2);
    for [0..size.rows - 1] do i:
        row = getRow i m1;
        for [0..size.columns - 1] do j:
            data[j][i] := bf.sum (bf.multiply row (m2cols[j]));
        done;
    done;
    DenseCols (array (map vec.vector (list data))));
Chris@249 543
// Matrix product m1 * m2. Fails unless m1 has as many columns as m2
// has rows. The implementation depends on the inputs' storage:
//   sparse * sparse -> sparse column-major (sparseProduct)
//   sparse * dense  -> dense column-major  (sparseProductLeft)
//   dense  * sparse -> dense row-major     (sparseProductRight)
//   dense  * dense  -> dense column-major  (denseProduct)
product m1 m2 =
   (s1 = size m1;
    s2 = size m2;
    if s1.columns != s2.rows then
        failWith "Matrix dimensions incompatible: \(size m1), \(size m2) (\((size m1).columns) != \((size m2).rows))";
    else
        outSize = { rows = s1.rows, columns = s2.columns };
        if isSparse? m1 and isSparse? m2 then
            sparseProduct outSize m1 m2
        elif isSparse? m1 then
            sparseProductLeft outSize m1 m2
        elif isSparse? m2 then
            sparseProductRight outSize m1 m2
        else
            denseProduct outSize m1 m2
        fi
    fi);
Chris@98 561
// Join matrices by splicing their vectors end to end. Vector i of the
// result is the concatenation of vector i (fetched with getter) of each
// matrix in mm. counter gives the number of vectors from the first
// matrix's size. tagger wraps the result (DenseRows or DenseCols).
concatAgainstGrain tagger getter counter mm =
   (count = counter (size (head mm));
    spliced i = vec.concat (map (getter i) mm);
    tagger (array (map spliced [0 .. count - 1])));

// Join matrices by stacking their vectors. The result holds all the
// vectors (fetched with getter) of the first matrix, then those of the
// next, and so on. counter gives each matrix's vector count from its
// size. tagger wraps the result.
concatWithGrain tagger getter counter mm =
   (vectorsOf m = map ((flip getter) m) [0 .. counter (size m) - 1];
    tagger (array (concat (map vectorsOf mm))));
Chris@177 576
// Concatenate sparse matrices. Their entries are re-enumerated with row
// offsets (Vertical ()) or column offsets (any other direction), and a
// new sparse matrix is built in first's storage order. Along the join
// direction the result's extent is the sum of the parts; across it,
// the extent is taken from first.
sparseConcat direction first mm =
   (vertical = direction == Vertical ();
    extentOf d f = if direction == d then sum (map f mm) else f first fi;
    rows = extentOf (Vertical ()) height;
    columns = extentOf (Horizontal ()) width;
    shifted ioff joff parts =
        case parts of
        m::rest:
            (map do { i, j, v }: { i = i + ioff, j = j + joff, v }
             done (enumerate m))
            ++ (if vertical then shifted (ioff + height m) joff rest
                else shifted ioff (joff + width m) rest
                fi);
        _: [];
        esac;
    makeSparse (typeOf first) { rows, columns } (shifted 0 0 mm));
Chris@259 595
// Fail unless every matrix in mm matches first in the dimension that
// must agree when concatenating in direction: the row count for
// Horizontal (), otherwise the column count.
checkDimensionsFor direction first mm =
   (counter = if direction == Horizontal () then (.rows) else (.columns) fi;
    n = counter (size first);
    counts = map (counter . size) mm;
    if not (all do c: c == n done counts) then
        failWith "Matrix dimensions incompatible for concat (found \(map do m: counter (size m) done mm) not all of which are \(n))";
    fi);
Chris@178 602
// Concatenate the matrices in mm side by side (Horizontal ()) or one
// above another (Vertical ()). An empty list gives zeroSizeMatrix (),
// and a single matrix is returned as it is. If every input is sparse
// the result is sparse; otherwise it is dense.
concat direction mm = //!!! doc: storage order is taken from first matrix in sequence
    if empty? mm then
        zeroSizeMatrix ()
    elif empty? (tail mm) then
        head mm
    else
        first = head mm;
        checkDimensionsFor direction first mm;
        if all isSparse? mm then
            sparseConcat direction first mm
        elif direction == Horizontal () then
            // horizontal, row-major: against grain with rows
            // horizontal, col-major: with grain with cols
            if isRowMajor? first then concatAgainstGrain DenseRows getRow (.rows) mm
            else concatWithGrain DenseCols getColumn (.columns) mm
            fi
        else
            // vertical, row-major: with grain with rows
            // vertical, col-major: against grain with cols
            if isRowMajor? first then concatWithGrain DenseRows getRow (.rows) mm
            else concatAgainstGrain DenseCols getColumn (.columns) mm
            fi
        fi
    fi;
Chris@177 630
//!!! doc note: argument order chosen for consistency with std module slice
// Rows start to end - 1 of m, as a dense matrix.
rowSlice m start end = //!!! doc: storage order same as input
    if not isRowMajor? m then
        DenseCols (array (map do col: vec.slice col start end done (asColumns m)))
    else
        DenseRows (array (map ((flip getRow) m) [start .. end - 1]))
    fi;

//!!! doc note: argument order chosen for consistency with std module slice
// Columns start to end - 1 of m, as a dense matrix.
columnSlice m start end = //!!! doc: storage order same as input
    if isRowMajor? m then
        DenseRows (array (map do row: vec.slice row start end done (asRows m)))
    else
        DenseCols (array (map ((flip getColumn) m) [start .. end - 1]))
    fi;
Chris@187 646
// Return m resized to newsize ({ rows, columns }), keeping its storage
// order. Cells within both the old and new bounds keep their values;
// added cells are zero.
// A sparse m stays sparse, with out-of-range entries dropped.
// A dense m is changed one step at a time, recursing until the size
// matches. It is truncated first, one dimension per step, preferring
// rows when row-major and columns when column-major. It is then padded
// with zero rows, then zero columns.
resizedTo newsize m =
    if newsize == size m then
        m
    elif isSparse? m then
        // don't call makeSparse directly: want to discard
        // out-of-range cells
        newSparseMatrix (typeOf m) newsize (enumerateSparse m)
    elif height m == 0 or width m == 0 then
        zeroMatrixWithTypeOf m newsize
    else
        extraRows = newsize.rows - height m;
        extraCols = newsize.columns - width m;
        rowMajor = isRowMajor? m;
        stepped =
            if extraRows < 0 and (rowMajor or extraCols >= 0) then
                rowSlice m 0 newsize.rows
            elif extraCols < 0 then
                columnSlice m 0 newsize.columns
            elif extraRows > 0 then
                concat (Vertical ())
                    [m, zeroMatrixWithTypeOf m ((size m) with { rows = extraRows })]
            else
                concat (Horizontal ())
                    [m, zeroMatrixWithTypeOf m ((size m) with { columns = extraCols })]
            fi;
        resizedTo newsize stepped
    fi;
Chris@201 679
// Public interface of the module. Every field listed in the type below
// must be defined in this record.
{
    size,
    width,
    height,
    density,
    nonZeroValues,
    at = at',
    getColumn,
    getRow,
    isRowMajor?,
    isSparse?,
    generate,
    constMatrix,
    randomMatrix,
    zeroMatrix,
    identityMatrix,
    zeroSizeMatrix,
    equal,
    equalUnder,
    transposed,
    flipped,
    toRowMajor,
    toColumnMajor,
    toSparse,
    toDense,
    scaled,
    resizedTo,
    asRows,
    asColumns,
    sum = sum',
    difference,
    abs = abs',
    filter,
    product,
    concat,
    rowSlice,
    columnSlice,
    newMatrix,
    newRowVector,
    newColumnVector,
    newSparseMatrix,
    enumerate
}
as
{
    //!!! check whether these are right to be .selector rather than just selector

    size is matrix -> { .rows is number, .columns is number },
    width is matrix -> number,
    height is matrix -> number,
    density is matrix -> number,
    nonZeroValues is matrix -> number,
    at is matrix -> number -> number -> number,
    getColumn is number -> matrix -> vector,
    getRow is number -> matrix -> vector,
    isRowMajor? is matrix -> boolean,
    isSparse? is matrix -> boolean,
    generate is (number -> number -> number) -> { .rows is number, .columns is number } -> matrix,
    constMatrix is number -> { .rows is number, .columns is number } -> matrix,
    randomMatrix is { .rows is number, .columns is number } -> matrix,
    zeroMatrix is { .rows is number, .columns is number } -> matrix,
    identityMatrix is { .rows is number, .columns is number } -> matrix,
    zeroSizeMatrix is () -> matrix,
    equal is matrix -> matrix -> boolean,
    equalUnder is (number -> number -> boolean) -> matrix -> matrix -> boolean,
    transposed is matrix -> matrix,
    flipped is matrix -> matrix,
    toRowMajor is matrix -> matrix,
    toColumnMajor is matrix -> matrix,
    toSparse is matrix -> matrix,
    toDense is matrix -> matrix,
    scaled is number -> matrix -> matrix,
    resizedTo is { .rows is number, .columns is number } -> matrix -> matrix,
    asRows is matrix -> list<vector>,
    asColumns is matrix -> list<vector>,
    sum is matrix -> matrix -> matrix,
    difference is matrix -> matrix -> matrix,
    abs is matrix -> matrix,
    // thresholding is provided by filter; there is no separate
    // thresholded function in this module
    filter is (number -> boolean) -> matrix -> matrix,
    product is matrix -> matrix -> matrix,
    concat is (Horizontal () | Vertical ()) -> list<matrix> -> matrix,
    rowSlice is matrix -> number -> number -> matrix,
    columnSlice is matrix -> number -> number -> matrix,
    newMatrix is (ColumnMajor () | RowMajor ()) -> list<vector> -> matrix,
    newRowVector is vector -> matrix,
    newColumnVector is vector -> matrix,
    newSparseMatrix is (ColumnMajor () | RowMajor ()) -> { .rows is number, .columns is number } -> list<{ .i is number, .j is number, .v is number }> -> matrix,
    enumerate is matrix -> list<{ .i is number, .j is number, .v is number }>
}
Chris@5 770