Mercurial > hg > may
view yetilab/matrix/matrix.yeti @ 178:032c4986b6b0
Implement and test matrix concat
author | Chris Cannam |
---|---|
date | Thu, 02 May 2013 21:58:58 +0100 |
parents | 370d9350495c |
children | fd16551c2fc8 |
line wrap: on
line source
module yetilab.matrix.matrix;

// A matrix is an array of fvectors (i.e. primitive double[]s).

// A matrix can be stored in either column-major (the default) or
// row-major format. Storage order is an efficiency concern only:
// every API function operating on matrix objects will return the same
// result regardless of storage order. (The transpose function just
// switches the row/column order without moving the elements.)

vec = load yetilab.block.fvector;
block = load yetilab.block.block;
bf = load yetilab.block.blockfuncs;

// Wrap raw storage into a matrix object. d is either RowM (an array of
// row vectors) or ColM (an array of column vectors); d itself is kept
// as-is, not copied.
make d = {
    get data () = d,

    // Dimensions as { rows, columns }. Storage holding no vectors
    // reports 0 in both dimensions.
    get size () =
        case d of
        RowM r:
            major = length r;
            {
                rows = major,
                columns = if major > 0 then vec.length r[0] else 0 fi,
            };
        ColM c:
            major = length c;
            {
                rows = if major > 0 then vec.length c[0] else 0 fi,
                columns = major,
            };
        esac,

    // Column j as a block. Row-major storage has to assemble it
    // element by element; column-major storage wraps the stored vector.
    getColumn j =
        case d of
        RowM rows: block.fromList (map do i: getAt i j done [0..length rows-1]);
        ColM cols: block.block cols[j];
        esac,

    // Row i as a block. Row-major storage wraps the stored vector;
    // column-major storage has to assemble it element by element.
    getRow i =
        case d of
        RowM rows: block.block rows[i];
        ColM cols: block.fromList (map do j: getAt i j done [0..length cols-1]);
        esac,

    // Element at (row, col), read from whichever vector holds it.
    getAt row col =
        case d of
        RowM rows: r = rows[row]; (r is ~double[])[col];
        ColM cols: c = cols[col]; (c is ~double[])[row];
        esac,

    // Overwrite the element at (row, col) in place.
    setAt row col n = //!!! dangerous, could modify copies -- should it be allowed?
        case d of
        RowM rows: r = rows[row]; (r is ~double[])[col] := n;
        ColM cols: c = cols[col]; (c is ~double[])[row] := n;
        esac,

    get isRowMajor? () =
        case d of
        RowM _: true;
        ColM _: false;
        esac,
};

// Allocate column-major storage: `columns` zero-filled vectors of
// length `rows`, or no vectors at all when rows < 1.
newStorage { rows, columns } =
    if rows < 1 then array []
    else array (map \(vec.zeros rows) [1..columns])
    fi;

// Column-major matrix of the given size with every element 0.
zeroMatrix { rows, columns } = make (ColM (newStorage { rows, columns }));

// Column-major matrix of the given size whose element at (row, col)
// is f row col.
generate f { rows, columns } =
   (m = newStorage { rows, columns };
    for [0..columns-1] do col:
        for [0..rows-1] do row:
            m[col][row] := f row col;
        done;
    done;
    make (ColM m));

// Every element equal to n.
constMatrix n = generate do row col: n done;

// Every element drawn from Math#random().
randomMatrix = generate do row col: Math#random() done;

// Identity matrix: 1 where row == col and 0 everywhere else. For a
// non-square size the ones run along the leading diagonal.
// (Previously this was constMatrix 1, which filled every element with 1.)
identityMatrix = generate do row col: if row == col then 1 else 0 fi done;

zeroSizeMatrix () = zeroMatrix { rows = 0, columns = 0 };

width m = m.size.columns;

height m = m.size.rows;

// Reinterpret the same storage with the other orientation; no
// elements are moved or copied.
transposed m =
    make
       (case m.data of
        RowM d: ColM d;
        ColM d: RowM d;
        esac);

// Same contents in the opposite storage order (copies every element).
flipped m =
    if m.isRowMajor? then
        generate do row col: m.getAt row col done m.size;
    else
        // Build the transpose in column-major form, then reinterpret it
        // as row-major so that its contents match m again
        transposed
           (generate do row col: m.getAt col row done
                { rows = m.size.columns, columns = m.size.rows });
    fi;

toRowMajor m =
    if m.isRowMajor? then m else flipped m fi;

toColumnMajor m =
    if not m.isRowMajor? then m else flipped m fi;

// Matrices with different storage order but the same contents are
// equal (but comparing them is slow)
equal m1 m2 =
    if m1.size != m2.size then false
    elif m1.isRowMajor? != m2.isRowMajor? then equal (flipped m1) m2;
    else
        compare d1 d2 = all id (map2 vec.equal d1 d2);
        case m1.data of
        RowM d1:
            case m2.data of RowM d2: compare d1 d2; _: false; esac;
        ColM d1:
            case m2.data of ColM d2: compare d1 d2; _: false; esac;
        esac
    fi;

// Copy with the same storage order; every stored vector goes
// through vec.copyOf.
copyOf m =
   (copyOfData d = (array (map vec.copyOf d));
    make
       (case m.data of
        RowM d: RowM (copyOfData d);
        ColM d: ColM (copyOfData d);
        esac));

// Matrix from a list of blocks, each becoming one row (RowMajor ()) or
// one column (ColumnMajor ()). An empty list, or one whose first block
// is empty, gives the zero-size matrix.
newMatrix type data = //!!! NB does not copy data
   (tagger = case type of RowMajor (): RowM; ColumnMajor (): ColM esac;
    if empty? data or block.empty? (head data)
    then zeroSizeMatrix ()
    else make (tagger (array (map block.data data)))
    fi);

// Single-row matrix from one block.
newRowVector data = //!!! NB does not copy data
    make (RowM (array [block.data data]));

// Single-column matrix from one block.
newColumnVector data = //!!! NB does not copy data
    make (ColM (array [block.data data]));

// Every element multiplied by factor; the result is column-major.
scaled factor m =
    generate do row col: factor * m.getAt row col done m.size;

// m itself if it already has size newsize; otherwise a column-major
// matrix of size newsize keeping the overlapping top-left region of m
// and filling the rest with 0.
resizedTo newsize m =
   (oldsize = m.size;
    if newsize == oldsize then m
    else
        generate do row col:
            if row < oldsize.rows and col < oldsize.columns
            then m.getAt row col
            else 0
            fi
        done newsize;
    fi);

// Element-wise sum; fails unless both matrices have the same size.
sum' m1 m2 =
    if m1.size != m2.size then
        failWith "Matrices are not the same size: \(m1.size), \(m2.size)";
    else
        generate do row col: m1.getAt row col + m2.getAt row col done m1.size;
    fi;

// All rows of m, in order, as blocks.
asRows m = map m.getRow [0 .. m.size.rows - 1];

// All columns of m, in order, as blocks.
asColumns m = map m.getColumn [0 .. m.size.columns - 1];

// Matrix product m1 * m2; fails unless m1 has as many columns as m2
// has rows. The result is column-major.
product m1 m2 =
    if m1.size.columns != m2.size.rows then
        failWith "Matrix dimensions incompatible: \(m1.size), \(m2.size) (\(m1.size.columns) != \(m2.size.rows))";
    else
        // Extract each row of m1 and each column of m2 once, rather
        // than once per output element (extracting against the
        // storage grain rebuilds the vector element by element)
        leftRows = array (asRows m1);
        rightColumns = array (asColumns m2);
        generate do row col:
            bf.sum (bf.multiply leftRows[row] rightColumns[col])
        done { rows = m1.size.rows, columns = m2.size.columns }
    fi;

// Concatenate across the storage grain: output vector i is the
// concatenation of vector i (fetched with getter) from every matrix
// in mm; counter gives how many such vectors the first matrix has.
concatAgainstGrain tagger getter counter mm =
   (n = counter (head mm).size;
    make (tagger (array (map do i:
        block.data (block.concat (map do m: getter m i done mm))
    done [0..n-1]))));

// Concatenate along the storage grain: all vectors (fetched with
// getter) of every matrix in mm, one after another.
// NB: `concat` here must be the standard list concat; the module's own
// concat is bound further down, so keep this definition above it.
concatWithGrain tagger getter counter mm =
    make (tagger (array (concat (map do m:
        n = counter m.size;
        map do i: block.data (getter m i) done [0..n-1]
    done mm))));

// Fail unless every matrix in mm matches first in the dimension that
// must agree: rows for Horizontal (), columns otherwise.
checkDimensionsFor direction first mm =
   (counter = if direction == Horizontal () then (.rows) else (.columns) fi;
    n = counter first.size;
    if not (all id (map do m: counter m.size == n done mm)) then
        failWith "Matrix dimensions incompatible for concat (found \(map do m: counter m.size done mm) not all of which are \(n))";
    fi);

// Concatenate the matrices in mm side by side (Horizontal ()) or one
// above the other (Vertical ()). An empty list gives the zero-size
// matrix.
concat direction mm = //!!! storage order is taken from first matrix in sequence
    if empty? mm then zeroSizeMatrix ()
    else
        first = head mm;
        checkDimensionsFor direction first mm;
        row = first.isRowMajor?;
        // horizontal, row-major: against grain with rows
        // horizontal, col-major: with grain with cols
        // vertical, row-major: with grain with rows
        // vertical, col-major: against grain with cols
        case direction of
        Horizontal ():
            if row then concatAgainstGrain RowM (.getRow) (.rows) mm;
            else concatWithGrain ColM (.getColumn) (.columns) mm;
            fi;
        Vertical ():
            if row then concatWithGrain RowM (.getRow) (.rows) mm;
            else concatAgainstGrain ColM (.getColumn) (.columns) mm;
            fi;
        esac;
    fi;

{
    constMatrix, randomMatrix, zeroMatrix, identityMatrix, zeroSizeMatrix,
    generate,
    width, height,
    equal,
    copyOf,
    transposed,
    flipped,
    toRowMajor, toColumnMajor,
    scaled,
    resizedTo,
    asRows, asColumns,
    sum = sum',
    product,
    concat,
    newMatrix, newRowVector, newColumnVector,
}