To check out this repository please hg clone the following URL, or open the URL using EasyMercurial or your preferred Mercurial client.

Statistics Download as Zip
| Branch: | Revision:

root / src / may / matrix.yeti @ 597:c62894d056c7

History | View | Annotate | Download (36.7 KB)

1

    
2
/**
3
 * Matrices. A matrix is a two-dimensional (NxM) container of
4
 * double-precision floating point values.
5
 *
6
 * A matrix may be dense or sparse.
7
 * 
8
 * A dense matrix (the default) is just a series of vectors, making up
9
 * the matrix "grid". The values may be stored in either column-major
10
 * order, in which case the series consists of one vector for each
11
 * column in the matrix, or row-major order, in which case the series
12
 * consists of one vector for each row. The default is column-major.
13
 * 
14
 * A sparse matrix has a more complex representation in which only the
15
 * non-zero values are stored. This is typically used for matrices
16
 * containing sparse data, that is, data in which most of the values
17
 * are zero: using a sparse representation is more efficient than a
18
 * dense one (in both time and memory) if the matrix is very large but
19
 * contains a relatively low proportion of non-zero values. Like dense
20
 * matrices, sparse ones may be column-major or row-major.
21
 * 
22
 * The choice of dense or sparse, row- or column-major is a question
23
 * of efficiency alone. All functions in this module should return the
24
 * same results regardless of how the matrices they operate on are
25
 * represented. However, differences in performance can be very large
26
 * and it is often worth converting matrices to a different storage
27
 * format if you know they can be more efficiently manipulated that
28
 * way. For example, multiplying two matrices is fastest if the first
29
 * is in column-major and the second in row-major order.
30
 * 
31
 * Use the isRowMajor? and isSparse? functions to query the storage
32
 * format of a matrix; use the flipped function to convert between
33
 * column-major and row-major storage; and use toSparse and toDense to
34
 * convert between sparse and dense storage.
35
 *
36
 * Note that the matrix size is preserved even if at least one
37
 * dimension is zero. That is, it is legal to have matrices of size
38
 * 0x0, 0x4, 1x0 etc, and they are distinct from each other.
39
 */
40

    
41
module may.matrix;
42

    
43
{ ceil, floor, random } = load may.mathmisc;
44

    
45
vec = load may.vector;
46

    
47
load yeti.json;
48

    
49
typedef opaque matrix_t = {
50
    size is { rows is number, columns is number },
51
    data is
52
        DenseRows array<vec.vector_t> | // array of rows
53
        DenseCols array<vec.vector_t> | // array of columns
54
        SparseCSR {
55
            values is vec.vector_t,
56
            indices is array<number>, // column index of each value
57
            pointers is array<number>, // offset of first value in each row
58
            extent is number // max possible index + 1, i.e. number of columns
59
        } |
60
        SparseCSC {
61
            values is vec.vector_t,
62
            indices is array<number>, // row index of each value
63
            pointers is array<number>, // offset of first value in each column
64
            extent is number // max pointers index + 1, i.e. number of rows
65
        }
66
};
67

    
68
size m = m.size;
69
width m = m.size.columns;
70
height m = m.size.rows;
71

    
72
nonZeroValues m =
73
   (nz d =
74
        sum
75
           (map do v:
76
                sum (map do n: if n == 0 then 0 else 1 fi done (vec.list v))
77
                done d);
78
    case m.data of 
79
    DenseRows d: nz d;
80
    DenseCols d: nz d;
81
    SparseCSR d: vec.length d.values;
82
    SparseCSC d: vec.length d.values;
83
    esac);
84

    
85
density m =
86
   ({ rows, columns } = size m;
87
    cells = rows * columns;
88
    (nonZeroValues m) / cells);
89

    
90
sparseSlice n d =
91
   (start = d.pointers[n];
92
    end = d.pointers[n+1];
93
    { 
94
        values = vec.slice d.values start end,
95
        indices = slice d.indices start end,
96
    });
97

    
98
nonEmptySlices d =
99
   (ne = array [];
100
    for [0..length d.pointers - 2] do i:
101
        if d.pointers[i] != d.pointers[i+1] then
102
            push ne i
103
        fi
104
    done;
105
    ne);
106

    
107
fromSlice n m d =
108
   (slice = sparseSlice n d;
109
    var v = 0;
110
    for [0..length slice.indices - 1] do i:
111
        if slice.indices[i] == m then
112
            v := vec.at slice.values i;
113
        fi
114
    done;
115
    v);
116

    
117
filledSlice n d =
118
   (slice = sparseSlice n d;
119
    dslice = new double[d.extent];
120
    for [0..length slice.indices - 1] do i:
121
        dslice[slice.indices[i]] := vec.at slice.values i;
122
    done;
123
    vec.vector dslice);
124

    
125
at' m row col =
126
    case m.data of
127
    DenseRows rows: r = rows[row]; vec.at r col;
128
    DenseCols cols: c = cols[col]; vec.at c row;
129
    SparseCSR data: fromSlice row col data;
130
    SparseCSC data: fromSlice col row data;
131
    esac;
132

    
133
//!!! better as getXx or just xx?
134

    
135
//!!! arguably getRow, getColumn, getDiagonal should have m as first arg for symmetry with at
136
getColumn j m =
137
    case m.data of
138
    DenseCols cols: cols[j];
139
    SparseCSC data: filledSlice j data;
140
    _: vec.fromList (map do i: at' m i j done [0..height m - 1]);
141
    esac;
142

    
143
getRow i m =
144
    case m.data of
145
    DenseRows rows: rows[i];
146
    SparseCSR data: filledSlice i data; 
147
    _: vec.fromList (map do j: at' m i j done [0..width m - 1]);
148
    esac;
149

    
150
getDiagonal k m =
151
   (ioff = if k < 0 then -k else 0 fi;
152
    joff = if k > 0 then  k else 0 fi;
153
    n = min (width m - joff) (height m - ioff);
154
    vec.fromList (map do i: at' m (i + ioff) (i + joff) done [0..n - 1]));
155

    
156
asRows m =
157
    map do i: getRow i m done [0 .. (height m) - 1];
158

    
159
asColumns m =
160
    map do i: getColumn i m done [0 .. (width m) - 1];
161

    
162
isRowMajor? m =
163
    case m.data of
164
    DenseRows _: true;
165
    DenseCols _: false;
166
    SparseCSR _: true;
167
    SparseCSC _: false;
168
    esac;
169

    
170
isSparse? m =
171
    case m.data of
172
    DenseRows _: false;
173
    DenseCols _: false;
174
    SparseCSR _: true;
175
    SparseCSC _: true;
176
    esac;
177

    
178
taggerForTypeOf m =
179
    if isRowMajor? m then Rows
180
    else Columns
181
    fi;
182

    
183
taggerForFlippedTypeOf m =
184
    if isRowMajor? m then Columns
185
    else Rows
186
    fi;
187

    
188
flippedSize { rows, columns } = { rows = columns, columns = rows };
189

    
190
newColumnMajorStorage { rows, columns } = 
191
    array (map \(vec.zeros rows) [1..columns]);
192

    
193
zeroMatrix size = 
194
    {
195
        size,
196
        data = DenseCols (newColumnMajorStorage size)
197
    };
198

    
199
zeroMatrixWithTypeOf m size = 
200
    {
201
        size,
202
        data =
203
            if isRowMajor? m then
204
                DenseRows (newColumnMajorStorage (flippedSize size));
205
            else
206
                DenseCols (newColumnMajorStorage size);
207
            fi
208
    };
209

    
210
zeroSizeMatrix () = zeroMatrix { rows = 0, columns = 0 };
211

    
212
newMatrix size d =
213
    case d of
214
    Rows rr: 
215
        if (length rr) != size.rows then
216
            failWith "Wrong number of rows in row-major newMatrix (\(length rr), size calls for \(size.rows))";
217
        elif not (all do r: vec.length r == size.columns done rr) then
218
            failWith "Wrong or inconsistent number of columns in rows in row-major newMatrix (\(map vec.length rr)), size calls for \(size.columns))";
219
        else
220
            {
221
                size,
222
                data = DenseRows (array rr)
223
            }
224
        fi;
225
    Columns cc: 
226
        if (length cc) != size.columns then
227
            failWith "Wrong number of columns in column-major newMatrix (\(length cc), size calls for \(size.columns))";
228
        elif not (all do c: vec.length c == size.rows done cc) then
229
            failWith "Wrong or inconsistent number of rows in in columns in column-major newMatrix (\(map vec.length cc)), size calls for \(size.rows))";
230
        else
231
            {
232
                size,
233
                data = DenseCols (array cc)
234
            }
235
        fi;
236
    esac;
237

    
238
newMatrixMatching m d = newMatrix (size m) ((taggerForTypeOf m) d);
239

    
240
generate f { rows, columns } =
241
   (m = array (map \(new double[rows]) [1..columns]);
242
    for [0..columns-1] do col:
243
        for [0..rows-1] do row:
244
            m[col][row] := f row col;
245
        done;
246
    done;
247
    {
248
        size = { rows, columns },
249
        data = DenseCols (array (map vec.vector m))
250
    });
251

    
252
swapij =
253
    map do { i, j, v }: { i = j, j = i, v } done;
254

    
255
//!!! should use { row = , column = , value = } instead of i, j, v?
256
enumerateSparse m =
257
   (enumerate { values, indices, pointers } =
258
        concat
259
           (map do i:
260
                start = pointers[i];
261
                end = pointers[i+1];
262
                map2 do j v: { i, j, v } done 
263
                    (slice indices start end)
264
                    (vec.list (vec.slice values start end))
265
                done [0..length pointers - 2]);
266
    case m.data of
267
    SparseCSC d: swapij (enumerate d);
268
    SparseCSR d: enumerate d;
269
     _: [];
270
    esac);
271

    
272
enumerateDense m =
273
   (enumerate d =
274
        concat
275
           (map do i:
276
                vv = d[i];
277
                map2 do j v: { i, j, v } done
278
                    [0..vec.length vv - 1]
279
                    (vec.list vv);
280
                done [0..length d - 1]);
281
    case m.data of
282
    DenseCols c: swapij (enumerate c);
283
    DenseRows r: enumerate r;
284
     _: [];
285
    esac);
286

    
287
enumerate m =
288
    if isSparse? m then enumerateSparse m else enumerateDense m fi;
289

    
290
// Make a sparse matrix from entries whose i, j values are known to be
291
// within range
292
newSparse size d =
293
   (isRow = case d of Rows _: true; Columns _: false esac;
294
    data = case d of Rows rr: rr; Columns cc: cc esac;
295
    ordered = 
296
        sortBy do a b:
297
            if a.maj == b.maj then a.min < b.min else a.maj < b.maj fi
298
        done
299
           (map
300
                if isRow then
301
                    do { i, j, v }: { maj = i, min = j, v } done;
302
                else
303
                    do { i, j, v }: { maj = j, min = i, v } done;
304
                fi
305
               (filter do d: d.v != 0 done data));
306
    tagger = if isRow then SparseCSR else SparseCSC fi;
307
    majorSize = if isRow then size.rows else size.columns fi;
308
    minorSize = if isRow then size.columns else size.rows fi;
309
    pointers = array [0];
310
    setArrayCapacity pointers (size.rows + 1);
311
    fillPointers n i data =
312
        if n < majorSize then
313
            case data of
314
            d::rest:
315
               (for [n..d-1] \(push pointers i);
316
                fillPointers d (i+1) rest);
317
             _:
318
                for [n..majorSize-1] \(push pointers i);
319
            esac;
320
        fi;
321
    fillPointers 0 0 (map (.maj) ordered);
322
    {
323
        size,
324
        data = tagger {
325
            values = vec.fromList (map (.v) ordered),
326
            indices = array (map (.min) ordered),
327
            pointers,
328
            extent = minorSize,
329
        }
330
    });
331

    
332
newSparseMatching m d = newSparse (size m) ((taggerForTypeOf m) d);
333

    
334
// Make a sparse matrix from entries that may contain out-of-range
335
// cells which need to be filtered out. This is the public API for
336
// newSparse and is also used to discard out-of-range cells from
337
// resizedTo.
338
//!!! doc: i is row number, j is column number (throughout, for sparse stuff). Would calling them row/column be better?
339
//!!! doc: Rows/Columns determines the storage order, the input data are treated the same either way (perhaps this does mean row/column would be better than i/j)
340
newSparseMatrix size d =
341
   (tagger = case d of Rows _: Rows; Columns _: Columns esac;
342
    data = case d of Rows rr: rr; Columns cc: cc esac;
343
    data = filter
344
        do { i, j, v }:
345
            i == int i and i >= 0 and i < size.rows and 
346
            j == int j and j >= 0 and j < size.columns
347
        done data;
348
    newSparse size (tagger data));
349

    
350
toSparse m =
351
    if isSparse? m then m
352
    else newSparseMatching m (enumerateDense m);
353
    fi;
354

    
355
toDense m =
356
    {
357
        size = (size m),
358
        data = 
359
            if not (isSparse? m) then m.data
360
            elif isRowMajor? m then
361
                DenseRows (array (map do row: getRow row m done [0..height m - 1]));
362
            else
363
                DenseCols (array (map do col: getColumn col m done [0..width m - 1]));
364
            fi
365
    };
366

    
367
constMatrix n = generate do row col: n done;
368
randomMatrix = generate do row col: random () done;
369
identityMatrix = constMatrix 1;
370

    
371
transposed m =
372
    {
373
        size = flippedSize (size m),
374
        data = 
375
            case m.data of
376
            DenseRows d: DenseCols d;
377
            DenseCols d: DenseRows d;
378
            SparseCSR d: SparseCSC d;
379
            SparseCSC d: SparseCSR d;
380
            esac
381
    };
382

    
383
flipped m =
384
    if isSparse? m then
385
        newSparse (size m) ((taggerForFlippedTypeOf m) (enumerateSparse m))
386
    else
387
        if isRowMajor? m then
388
            generate do row col: at' m row col done (size m);
389
        else
390
            transposed
391
               (generate do row col: at' m col row done (flippedSize (size m)));
392
        fi
393
    fi;
394

    
395
toRowMajor m =
396
    if isRowMajor? m then m else flipped m fi;
397

    
398
toColumnMajor m =
399
    if not isRowMajor? m then m else flipped m fi;
400

    
401
equal'' comparator vecComparator m1 m2 =
402
    // Prerequisite: m1 and m2 have same sparse-p and storage order
403
   (compareVecLists vv1 vv2 = all id (map2 vecComparator vv1 vv2);
404
    compareSparse d1 d2 =
405
        d1.extent == d2.extent and
406
        vecComparator d1.values d2.values and
407
        d1.indices == d2.indices and
408
        d1.pointers == d2.pointers;
409
    case m1.data of
410
    DenseRows d1:
411
        case m2.data of DenseRows d2: compareVecLists d1 d2; _: false; esac;
412
    DenseCols d1:
413
        case m2.data of DenseCols d2: compareVecLists d1 d2; _: false; esac;
414
    SparseCSR d1:
415
        case m2.data of SparseCSR d2: compareSparse d1 d2; _: false; esac;
416
    SparseCSC d1:
417
        case m2.data of SparseCSC d2: compareSparse d1 d2; _: false; esac;
418
    esac);
419

    
420
equal' comparator vecComparator m1 m2 =
421
    if size m1 != size m2 then 
422
        false
423
    elif isRowMajor? m1 != isRowMajor? m2 then
424
        equal' comparator vecComparator (flipped m1) m2;
425
    elif isSparse? m1 != isSparse? m2 then
426
        if isSparse? m1 then
427
            equal' comparator vecComparator m1 (toSparse m2)
428
        else
429
            equal' comparator vecComparator (toSparse m1) m2
430
        fi
431
    else
432
        equal'' comparator vecComparator m1 m2
433
    fi;
434

    
435
// Compare matrices using the given comparator for individual cells.
436
// Note that matrices with different storage order but the same
437
// contents are equal, although comparing them is slow.
438
//!!! Document the fact that sparse matrices can only be equal if they
439
// have the same set of non-zero cells (regardless of comparator used)
440
equalUnder comparator =
441
    equal' comparator (vec.equalUnder comparator);
442

    
443
equal =
444
    equal' (==) vec.equal;
445

    
446
fromRows rows =
447
   (if any do r: vec.length r != vec.length (head rows) done rows then
448
        failWith "Inconsistent row lengths in fromRows (\(map vec.length rows))";
449
    fi;
450
    {
451
        size = { 
452
            rows = length rows, 
453
            columns = 
454
                if empty? rows then 0
455
                else vec.length (head rows) 
456
                fi,
457
        },
458
        data = DenseRows (array rows)
459
    });
460

    
461
fromColumns cols =
462
   (if any do c: vec.length c != vec.length (head cols) done cols then
463
        failWith "Inconsistent column lengths in fromColumns (\(map vec.length cols))";
464
    fi;
465
    {
466
        size = { 
467
            columns = length cols, 
468
            rows = 
469
                if empty? cols then 0
470
                else vec.length (head cols) 
471
                fi,
472
        },
473
        data = DenseCols (array cols)
474
    });
475

    
476
fromLists data = 
477
    case data of
478
    Rows rr: fromRows (map vec.fromList rr);
479
    Columns cc: fromColumns (map vec.fromList cc);
480
    esac;
481

    
482
newRowVector data = //!!! NB does not copy data
483
    fromRows (array [data]);
484

    
485
newColumnVector data = //!!! NB does not copy data
486
    fromColumns (array [data]);
487

    
488
denseLinearOp op m1 m2 =
489
    if isRowMajor? m1 then
490
        newMatrixMatching m1
491
           (map2 do c1 c2: op c1 c2 done (asRows m1) (asRows m2));
492
    else
493
        newMatrixMatching m1
494
           (map2 do c1 c2: op c1 c2 done (asColumns m1) (asColumns m2));
495
    fi;
496

    
497
sparseSumOrDifference op m1 m2 =
498
   (h = [:];
499
    for (enumerate m1) do { i, j, v }:
500
        if not (i in h) then h[i] := [:] fi;
501
        h[i][j] := v;
502
    done;
503
    for (enumerate m2) do { i, j, v }:
504
        if not (i in h) then h[i] := [:] fi;
505
        if j in h[i] then h[i][j] := op h[i][j] v;
506
        else h[i][j] := op 0 v;
507
        fi;
508
    done;
509
    entries = concat
510
       (map do i:
511
            kk = keys h[i];
512
            map2 do j v: { i, j, v } done kk (map (at h[i]) kk)
513
            done (keys h));
514
    newSparseMatching m1 entries);
515

    
516
sum' mm =
517
    case mm of
518
    m1::m2::rest:
519
        sum' 
520
           (if (size m1) != (size m2)
521
            then failWith "Matrices are not the same size: \(size m1), \(size m2)";
522
            elif isSparse? m1 and isSparse? m2 then
523
                sparseSumOrDifference (+) m1 m2;
524
            else
525
                add2 v1 v2 = vec.add [v1,v2];
526
                denseLinearOp add2 m1 m2;
527
            fi :: rest);
528
    [m1]: m1;
529
    _: failWith "Empty argument list";
530
    esac;
531
    
532
difference m1 m2 =
533
    if (size m1) != (size m2)
534
    then failWith "Matrices are not the same size: \(size m1), \(size m2)";
535
    elif isSparse? m1 and isSparse? m2 then
536
        sparseSumOrDifference (-) m1 m2;
537
    else
538
        denseLinearOp vec.subtract m1 m2;
539
    fi;
540

    
541
scaled factor m =
542
    if isSparse? m then
543
        newSparseMatching m
544
           (map do { i, j, v }: { i, j, v = factor * v } done (enumerate m))
545
    elif isRowMajor? m then
546
        newMatrixMatching m (map (vec.scaled factor) (asRows m));
547
    else
548
        newMatrixMatching m (map (vec.scaled factor) (asColumns m));
549
    fi;
550

    
551
abs' m =
552
    if isSparse? m then
553
        newSparseMatching m
554
           (map do { i, j, v }: { i, j, v = abs v } done (enumerate m))
555
    elif isRowMajor? m then
556
        newMatrixMatching m (map vec.abs (asRows m));
557
    else
558
        newMatrixMatching m (map vec.abs (asColumns m));
559
    fi;
560

    
561
negative m =
562
    if isSparse? m then
563
        newSparseMatching m
564
           (map do { i, j, v }: { i, j, v = (-v) } done (enumerate m))
565
    elif isRowMajor? m then
566
        newMatrixMatching m (map vec.negative (asRows m));
567
    else
568
        newMatrixMatching m (map vec.negative (asColumns m));
569
    fi;
570

    
571
//!!! doc: filter by predicate, always returns sparse matrix
572
filter' f m =
573
    newSparseMatching m
574
       (map do { i, j, v }: { i, j, v = if f v then v else 0 fi } done
575
           (enumerate m));
576

    
577
any' f m =
578
    any f (map (.v) (enumerate m));
579

    
580
all' f m =
581
    all f (map (.v) (enumerate m));
582

    
583
sparseProductLeft size m1 m2 =
584
   ({ values, indices, pointers } = 
585
        case m1.data of
586
        SparseCSR d: d;
587
        SparseCSC d: d;
588
        _: failWith "sparseProductLeft called for non-sparse m1";
589
        esac;
590
    rows = isRowMajor? m1;
591
    data = array (map \(new double[size.rows]) [1..size.columns]);
592
    for [0..size.columns - 1] do j':
593
        c = getColumn j' m2;
594
        var p = 0;
595
        for [0..length indices - 1] do ix:
596
            ix == pointers[p+1] loop (p := p + 1);
597
            i = if rows then p else indices[ix] fi;
598
            j = if rows then indices[ix] else p fi;
599
            data[j'][i] := data[j'][i] + (vec.at values ix) * (vec.at c j);
600
        done;
601
    done;
602
    newMatrix size (Columns (array (map vec.vector (list data)))));
603

    
604
sparseProductRight size m1 m2 =
605
   ({ values, indices, pointers } = 
606
        case m2.data of
607
        SparseCSR d: d;
608
        SparseCSC d: d;
609
        _: failWith "sparseProductLeft called for non-sparse m1";
610
        esac;
611
    rows = isRowMajor? m2;
612
    data = array (map \(new double[size.columns]) [1..size.rows]);
613
    for [0..size.rows - 1] do i':
614
        r = getRow i' m1;
615
        var p = 0;
616
        for [0..length indices - 1] do ix:
617
            ix == pointers[p+1] loop (p := p + 1);
618
            i = if rows then p else indices[ix] fi;
619
            j = if rows then indices[ix] else p fi;
620
            data[i'][j] := data[i'][j] + (vec.at values ix) * (vec.at r i);
621
        done;
622
    done;
623
    newMatrix size (Rows (array (map vec.vector (list data)))));
624

    
625
sparseProduct size m1 m2 =
626
    case m2.data of
627
    SparseCSC d:
628
       ({ values, indices, pointers } =
629
            case m1.data of
630
            SparseCSR d1: d1;
631
            SparseCSC d1: d1;
632
            _: failWith "sparseProduct called for non-sparse matrices";
633
            esac;
634
        rows = isRowMajor? m1;
635
        var p = 0;
636
        pindices = new int[length indices];
637
        for [0..length indices - 1] do ix:
638
            ix == pointers[p+1] loop (p := p + 1);
639
            pindices[ix] := p;
640
        done;
641
        entries =
642
           (map do j':
643
                cs = sparseSlice j' d;
644
                hin = mapIntoHash
645
                   (at cs.indices) (vec.at cs.values)
646
                   [0..length cs.indices - 1];
647
                hout = [:];
648
                for [0..length indices - 1] do ix:
649
                    i = if rows then pindices[ix] else indices[ix] fi;
650
                    j = if rows then indices[ix] else pindices[ix] fi;
651
                    if j in hin then
652
                        p = (vec.at values ix) * hin[j];
653
                        hout[i] := p + (if i in hout then hout[i] else 0 fi);
654
                    fi;
655
                done;
656
                map do i:
657
                    { i, j = j', v = hout[i] }
658
                done (keys hout);
659
            done (nonEmptySlices d));
660
        newSparse size (Columns (concat entries)));
661
    SparseCSR _:
662
        sparseProduct size m1 (flipped m2);
663
     _: failWith "sparseProduct called for non-sparse matrices";
664
    esac;
665

    
666
denseProduct size m1 m2 =
667
   (data = array (map \(new double[size.rows]) [1..size.columns]);
668
    for [0..size.rows - 1] do i:
669
        row = getRow i m1;
670
        for [0..size.columns - 1] do j:
671
            data[j][i] := vec.sum (vec.multiply [row, getColumn j m2]);
672
        done;
673
    done;
674
    newMatrix size (Columns (array (map vec.vector (list data)))));
675

    
676
product m1 m2 =
677
    if (size m1).columns != (size m2).rows
678
    then failWith "Matrix dimensions incompatible: \(size m1), \(size m2) (\((size m1).columns) != \((size m2).rows))";
679
    else 
680
        size = { rows = (size m1).rows, columns = (size m2).columns };
681
        if isSparse? m1 then
682
            if isSparse? m2 then
683
                sparseProduct size m1 m2
684
            else
685
                sparseProductLeft size m1 m2
686
            fi
687
        elif isSparse? m2 then
688
            sparseProductRight size m1 m2
689
        else
690
            denseProduct size m1 m2
691
        fi;
692
    fi;
693

    
694
entryWiseProduct mm =
695
    case mm of
696
    m1::m2::rest:
697
        entryWiseProduct
698
           (if (size m1) != (size m2)
699
            then failWith "Matrices are not the same size: \(size m1), \(size m2)";
700
            else 
701
                if isSparse? m1 then
702
                    newSparse (size m1)
703
                       ((taggerForTypeOf m1)
704
                           (map do { i, j, v }: { i, j, v = v * (at' m2 i j) } done
705
                               (enumerateSparse m1)))
706
                elif isSparse? m2 then
707
                    entryWiseProduct (m2::m1::rest)
708
                else
709
                    if isRowMajor? m1 then
710
                        fromRows (array (map2 do v1 v2: vec.multiply [v1,v2] done
711
                           (asRows m1) (asRows m2)));
712
                    else
713
                        fromColumns (array (map2 do v1 v2: vec.multiply [v1,v2] done
714
                           (asColumns m1) (asColumns m2)));
715
                    fi
716
                fi
717
            fi :: rest);
718
    [m1]: m1;
719
    _: failWith "Empty argument list";
720
    esac;
721

    
722
entryWiseDivide m1 m2 =
723
    if (size m1) != (size m2)
724
    then failWith "Matrices are not the same size: \(size m1), \(size m2)";
725
    else 
726
        if isSparse? m1 then
727
            newSparse (size m1)
728
               ((taggerForTypeOf m1)
729
                   (map do { i, j, v }: { i, j, v = v / (at' m2 i j) } done
730
                       (enumerateSparse m1)))
731
        // For m2 to be sparse makes no sense (divide by zero all over
732
        // the shop).
733
        else
734
            if isRowMajor? m1 then
735
                fromRows (array (map2 vec.divide (asRows m1) (asRows m2)));
736
            else
737
                fromColumns (array (map2 vec.divide (asColumns m1) (asColumns m2)));
738
            fi
739
        fi
740
    fi;
741

    
742
concatAgainstGrain tagger getter counter mm =
743
   (n = counter (size (head mm));
744
    tagger (array
745
       (map do i:
746
           vec.concat (map (getter i) mm)
747
           done [0..n-1])));
748

    
749
concatWithGrain tagger getter counter mm =
750
    tagger (array
751
       (concatMap do m:
752
           n = counter (size m);
753
           map do i: getter i m done [0..n-1]
754
        done mm));
755

    
756
sparseConcat direction first mm =
757
   (dimension d f = if direction == d then sum (map f mm) else f first fi;
758
    rows = dimension (Vertical ()) height;
759
    columns = dimension (Horizontal ()) width;
760
    entries ioff joff ui uj mm acc =
761
        case mm of 
762
        m::rest:
763
            entries
764
               (ioff + ui * height m)
765
               (joff + uj * width m)
766
                ui uj rest
767
               ((map do { i, j, v }: { i = i + ioff, j = j + joff, v }
768
                 done (enumerate m)) ++ acc);
769
         _: acc;
770
        esac;
771
    newSparse { rows, columns }
772
       ((taggerForTypeOf first)
773
           (if direction == Vertical () then entries 0 0 1 0 mm []
774
            else entries 0 0 0 1 mm [] fi)));
775

    
776
sumDimensions sumCounter checkCounter mm =
777
   (check = checkCounter (size (head mm));
778
    sum
779
       (map do m:
780
            s = size m;
781
            if (checkCounter s) != check then
782
                failWith "Matrix dimensions incompatible for concat (found \(map do m: checkCounter (size m) done mm) not all of which are \(check))";
783
            else
784
                sumCounter s;
785
            fi
786
        done mm));
787

    
788
concatHorizontal mm = //!!! doc: storage order is taken from first matrix in sequence; concat is obviously not lazy (unlike std module)
789
    case mm of
790
    [m]: m;
791
    first::rest:
792
       (w = sumDimensions (.columns) (.rows) mm;
793
        if all isSparse? mm then
794
            sparseConcat (Horizontal ()) first mm
795
        else
796
            row = isRowMajor? first;
797
            {
798
                size = { rows = height first, columns = w },
799
                data =
800
                    // horizontal, row-major: against grain with rows
801
                    // horizontal, col-major: with grain with cols
802
                    if row then concatAgainstGrain DenseRows getRow (.rows) mm;
803
                    else concatWithGrain DenseCols getColumn (.columns) mm;
804
                    fi
805
            };
806
        fi);
807
     _: zeroSizeMatrix ();
808
    esac;
809

    
810
concatVertical mm = //!!! doc: storage order is taken from first matrix in sequence; concat is obviously not lazy (unlike std module)
811
    case mm of
812
    [m]: m;
813
    first::rest:
814
       (h = sumDimensions (.rows) (.columns) mm;
815
        if all isSparse? mm then
816
            sparseConcat (Vertical ()) first mm
817
        else
818
            row = isRowMajor? first;
819
            {
820
                size = { rows = h, columns = width first },
821
                data = 
822
                    // vertical, row-major: with grain with rows
823
                    // vertical, col-major: against grain with cols
824
                    if row then concatWithGrain DenseRows getRow (.rows) mm;
825
                    else concatAgainstGrain DenseCols getColumn (.columns) mm;
826
                    fi,
827
            };
828
        fi);
829
     _: zeroSizeMatrix ();
830
    esac;
831

    
832
//!!! next two v. clumsy
833

    
834
//!!! doc note: argument order chosen for consistency with std module slice
835
//!!! NB always returns dense matrix, should have sparse version
836
rowSlice m start end = //!!! doc: storage order same as input
837
    if start < 0 then rowSlice m 0 end
838
    elif start > height m then rowSlice m (height m) end
839
    else
840
        if end < start then rowSlice m start start
841
        elif end > height m then rowSlice m start (height m)
842
        else
843
            if isRowMajor? m then
844
                newMatrix { rows = end - start, columns = width m }
845
                   (Rows
846
                       (array (map ((flip getRow) m) [start .. end - 1])))
847
            else 
848
                newMatrix { rows = end - start, columns = width m }
849
                   (Columns
850
                       (array (map do v: vec.slice v start end done (asColumns m))))
851
            fi;
852
        fi;
853
    fi;
854

    
855
//!!! doc note: argument order chosen for consistency with std module slice
856
//!!! NB always returns dense matrix, should have sparse version
857
columnSlice m start end = //!!! doc: storage order same as input
858
    if start < 0 then columnSlice m 0 end
859
    elif start > width m then columnSlice m (width m) end
860
    else
861
        if end < start then columnSlice m start start
862
        elif end > width m then columnSlice m start (width m)
863
        else
864
            if not isRowMajor? m then
865
                newMatrix { rows = height m, columns = end - start }
866
                   (Columns
867
                       (array (map ((flip getColumn) m) [start .. end - 1])))
868
            else 
869
                newMatrix { rows = height m, columns = end - start }
870
                   (Rows
871
                       (array (map do v: vec.slice v start end done (asRows m))))
872
            fi;
873
        fi;
874
    fi;
875

    
876
resizedTo newsize m =
877
   (if newsize == (size m) then
878
        m
879
    elif isSparse? m then
880
        // don't call newSparse directly: want to discard
881
        // out-of-range cells
882
        newSparseMatrix newsize ((taggerForTypeOf m) (enumerateSparse m))
883
    elif (height m) == 0 or (width m) == 0 then
884
        zeroMatrixWithTypeOf m newsize;
885
    else
886
        growrows = newsize.rows - (height m);
887
        growcols = newsize.columns - (width m);
888
        rowm = isRowMajor? m;
889
        resizedTo newsize
890
            if rowm and growrows < 0 then
891
                rowSlice m 0 newsize.rows
892
            elif (not rowm) and growcols < 0 then 
893
                columnSlice m 0 newsize.columns
894
            elif growrows < 0 then 
895
                rowSlice m 0 newsize.rows
896
            elif growcols < 0 then 
897
                columnSlice m 0 newsize.columns
898
            else
899
                if growrows > 0 then
900
                    concatVertical
901
                       [m, zeroMatrixWithTypeOf m ((size m) with { rows = growrows })]
902
                else
903
                    concatHorizontal
904
                       [m, zeroMatrixWithTypeOf m ((size m) with { columns = growcols })]
905
                fi
906
            fi
907
    fi);
908

    
909
//!!! doc: always dense
910
repeatedHorizontal n m =
911
   (if n == 1 then m
912
    else
913
        cols = asColumns m;
914
        fromColumns (fold do acc _: acc ++ cols done [] [1..n])
915
    fi);
916

    
917
//!!! doc: always dense
918
repeatedVertical n m =
919
   (if n == 1 then m
920
    else
921
        rows = asRows m;
922
        fromRows (fold do acc _: acc ++ rows done [] [1..n])
923
    fi);
924

    
925
//!!! doc: always dense
926
tiledTo newsize m =
927
    if newsize == size m then
928
        m
929
    elif (height m) == 0 or (width m) == 0 then
930
        zeroMatrixWithTypeOf m newsize;
931
    else    
932
        h = ceil (newsize.columns / (width m));
933
        v = ceil (newsize.rows / (height m));
934
        if isRowMajor? m then
935
            resizedTo newsize (repeatedHorizontal h (repeatedVertical v m))
936
        else
937
            resizedTo newsize (repeatedVertical v (repeatedHorizontal h m))
938
        fi
939
    fi;
940

    
941
minValue m =
942
    if width m == 0 or height m == 0 then 0
943
    elif isSparse? m then
944
        minv ll = fold min (head ll) (tail ll);
945
        minnz = minv (map (.v) (enumerate m));
946
        if minnz > 0 and nonZeroValues m < (width m * height m) then 0
947
        else minnz fi;
948
    elif isRowMajor? m then
949
        vec.min (vec.fromList (map vec.min (asRows m)));
950
    else
951
        vec.min (vec.fromList (map vec.min (asColumns m)));
952
    fi;
953

    
954
maxValue m =
955
    if width m == 0 or height m == 0 then 0
956
    elif isSparse? m then
957
        maxv ll = fold max (head ll) (tail ll);
958
        maxnz = maxv (map (.v) (enumerate m));
959
        if maxnz < 0 and nonZeroValues m < (width m * height m) then 0
960
        else maxnz fi;
961
    elif isRowMajor? m then
962
        vec.max (vec.fromList (map vec.max (asRows m)));
963
    else
964
        vec.max (vec.fromList (map vec.max (asColumns m)));
965
    fi;
966

    
967
total m = 
968
    if isSparse? m then
969
        fold (+) 0 (map (.v) (enumerateSparse m));
970
    elif isRowMajor? m then
971
        fold (+) 0 (map vec.sum (asRows m));
972
    else
973
        fold (+) 0 (map vec.sum (asColumns m));
974
    fi;
975

    
976
mapRows rf m =
977
    fromRows (map rf (asRows m));
978

    
979
mapColumns cf m =
980
    fromColumns (map cf (asColumns m));
981

    
982
format m =
983
    strJoin "\n"
984
       (chunk = 8;
985
        map do b:
986
            c0 = b * chunk;
987
            c1 = b * chunk + chunk - 1;
988
            c1 = if c1 > width m then width m else c1 fi;
989
            [ "\nColumns \(c0) to \(c1)\n",
990
              (map do row:
991
                   map do v:
992
                       strPad ' ' 10
993
                          (if v == 0 then "0.0"
994
                           elif abs v >= 1000.0 or abs v < 0.01 then
995
                               String#format("%.2E", [v as ~Double])
996
                           else
997
                               String#format("%5f", [v as ~Double])
998
                           fi);
999
                   done (vec.list row) |> strJoin "";
1000
               done (asRows (columnSlice m c0 (c1 + 1))) |> strJoin "\n")
1001
            ];
1002
        done [0..floor(width m / chunk)] |> concat);
1003

    
1004
print' = println . format;
1005
eprint' = eprintln . format;
1006

    
1007
json m =
1008
    jsOfList (map do r: jsOfList (map jsOfNum (vec.list r)) done (asRows m));
1009

    
1010
{
1011
    size,
1012
    width,
1013
    height,
1014
    density,
1015
    nonZeroValues,
1016
    at = at',
1017
    getColumn,
1018
    getRow,
1019
    getDiagonal,
1020
    isRowMajor?,
1021
    isSparse?,
1022
    generate,
1023
    constMatrix,
1024
    randomMatrix,
1025
    zeroMatrix,
1026
    identityMatrix,
1027
    equal, //!!! if empty is empty?, why is equal not equal? ?
1028
    equalUnder,
1029
    transposed,
1030
    flipped,
1031
    toRowMajor,
1032
    toColumnMajor,
1033
    toSparse,
1034
    toDense,
1035
    scaled,
1036
    minValue,
1037
    maxValue,
1038
    total,
1039
    asRows,
1040
    asColumns,
1041
    sum = sum',
1042
    difference,
1043
    abs = abs',
1044
    negative,
1045
    filter = filter',
1046
    all = all',
1047
    any = any',
1048
    product,
1049
    entryWiseProduct,
1050
    entryWiseDivide,
1051
    resizedTo,
1052
    tiledTo,
1053
    repeatedHorizontal,
1054
    repeatedVertical,
1055
    concatHorizontal,
1056
    concatVertical,
1057
    rowSlice,
1058
    columnSlice,
1059
    mapRows,
1060
    mapColumns,
1061
    fromRows,
1062
    fromColumns,
1063
    fromLists,
1064
    newMatrix,
1065
    newRowVector,
1066
    newColumnVector,
1067
    newSparseMatrix,
1068
    enumerate,
1069
    format,
1070
    print = print',
1071
    eprint = eprint',
1072
    json,
1073
}
1074
as
1075
{
1076
    size is matrix_t -> { rows is number, columns is number },
1077
    width is matrix_t -> number,
1078
    height is matrix_t -> number,
1079
    density is matrix_t -> number,
1080
    nonZeroValues is matrix_t -> number,
1081
    at is matrix_t -> number -> number -> number,
1082
    getColumn is number -> matrix_t -> vec.vector_t,
1083
    getRow is number -> matrix_t -> vec.vector_t,
1084
    getDiagonal is number -> matrix_t -> vec.vector_t,
1085
    isRowMajor? is matrix_t -> boolean,
1086
    isSparse? is matrix_t -> boolean,
1087
    generate is (number -> number -> number) -> { rows is number, columns is number } -> matrix_t,
1088
    constMatrix is number -> { rows is number, columns is number } -> matrix_t,
1089
    randomMatrix is { rows is number, columns is number } -> matrix_t,
1090
    zeroMatrix is { rows is number, columns is number } -> matrix_t, 
1091
    identityMatrix is { rows is number, columns is number } -> matrix_t, 
1092
    equal is matrix_t -> matrix_t -> boolean,
1093
    equalUnder is (number -> number -> boolean) -> matrix_t -> matrix_t -> boolean,
1094
    transposed is matrix_t -> matrix_t,
1095
    flipped is matrix_t -> matrix_t, 
1096
    toRowMajor is matrix_t -> matrix_t, 
1097
    toColumnMajor is matrix_t -> matrix_t,
1098
    toSparse is matrix_t -> matrix_t,
1099
    toDense is matrix_t -> matrix_t,
1100
    scaled is number -> matrix_t -> matrix_t,
1101
    minValue is matrix_t -> number,
1102
    maxValue is matrix_t -> number,
1103
    total is matrix_t -> number,
1104
    asRows is matrix_t -> list<vec.vector_t>, 
1105
    asColumns is matrix_t -> list<vec.vector_t>,
1106
    sum is list?<matrix_t> -> matrix_t,
1107
    difference is matrix_t -> matrix_t -> matrix_t,
1108
    abs is matrix_t -> matrix_t,
1109
    negative is matrix_t -> matrix_t,
1110
    filter is (number -> boolean) -> matrix_t -> matrix_t,
1111
    all is (number -> boolean) -> matrix_t -> boolean,
1112
    any is (number -> boolean) -> matrix_t -> boolean,
1113
    product is matrix_t -> matrix_t -> matrix_t,
1114
    entryWiseProduct is list?<matrix_t> -> matrix_t,
1115
    entryWiseDivide is matrix_t -> matrix_t -> matrix_t,
1116
    resizedTo is { rows is number, columns is number } -> matrix_t -> matrix_t,
1117
    tiledTo is { rows is number, columns is number } -> matrix_t -> matrix_t,
1118
    repeatedHorizontal is number -> matrix_t -> matrix_t,
1119
    repeatedVertical is number -> matrix_t -> matrix_t,
1120
    concatHorizontal is list<matrix_t> -> matrix_t,
1121
    concatVertical is list<matrix_t> -> matrix_t,
1122
    rowSlice is matrix_t -> number -> number -> matrix_t, 
1123
    columnSlice is matrix_t -> number -> number -> matrix_t,
1124
    mapRows is (vec.vector_t -> vec.vector_t) -> matrix_t -> matrix_t,
1125
    mapColumns is (vec.vector_t -> vec.vector_t) -> matrix_t -> matrix_t,
1126
    fromRows is list<vec.vector_t> -> matrix_t, 
1127
    fromColumns is list<vec.vector_t> -> matrix_t, 
1128
    fromLists is (Rows list<list<number>> | Columns list<list<number>>) -> matrix_t,
1129
    newMatrix is { rows is number, columns is number } -> (Rows list<vec.vector_t> | Columns list<vec.vector_t>) -> matrix_t,
1130
    newRowVector is vec.vector_t -> matrix_t, 
1131
    newColumnVector is vec.vector_t -> matrix_t,
1132
    newSparseMatrix is { rows is number, columns is number } -> (Rows list<{ i is number, j is number, v is number }> | Columns list<{ i is number, j is number, v is number }>) -> matrix_t,
1133
    enumerate is matrix_t -> list<{ i is number, j is number, v is number }>,
1134
    format is matrix_t -> string,
1135
    print is matrix_t -> (),
1136
    eprint is matrix_t -> (),
1137
    json is matrix_t -> json,
1138
}
1139