annotate armadillo-3.900.4/include/armadillo_bits/spglue_times_meat.hpp @ 84:55a047986812 tip

Update library URI so as not to be document-local
author Chris Cannam
date Wed, 22 Apr 2020 14:21:57 +0100
parents 1ec0e2823891
children
rev   line source
Chris@49 1 // Copyright (C) 2012 Ryan Curtin
Chris@49 2 // Copyright (C) 2012 Conrad Sanderson
Chris@49 3 //
Chris@49 4 // This Source Code Form is subject to the terms of the Mozilla Public
Chris@49 5 // License, v. 2.0. If a copy of the MPL was not distributed with this
Chris@49 6 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
Chris@49 7
Chris@49 8
Chris@49 9 //! \addtogroup spglue_times
Chris@49 10 //! @{
Chris@49 11 //! Evaluate the sparse product X.A * X.B and store the result in 'out'.
Chris@49 12 //! If 'out' is reported as aliasing either operand, the product is built in a
Chris@49 13 //! temporary matrix first and then handed over to 'out' with steal_mem().
Chris@49 14 template<typename T1, typename T2>
Chris@49 15 inline
Chris@49 16 void
Chris@49 17 spglue_times::apply(SpMat<typename T1::elem_type>& out, const SpGlue<T1,T2,spglue_times>& X)
Chris@49 18 {
Chris@49 19 arma_extra_debug_sigprint();
Chris@49 20
Chris@49 21 typedef typename T1::elem_type eT;
Chris@49 22
Chris@49 23 const SpProxy<T1> pa(X.A);
Chris@49 24 const SpProxy<T2> pb(X.B);
Chris@49 25 // apply_noalias() resets its output before reading the operands, so 'out' must not be one of them.
Chris@49 26 const bool is_alias = pa.is_alias(out) || pb.is_alias(out);
Chris@49 27
Chris@49 28 if(is_alias == false)
Chris@49 29 {
Chris@49 30 spglue_times::apply_noalias(out, pa, pb);
Chris@49 31 }
Chris@49 32 else
Chris@49 33 {
Chris@49 34 SpMat<eT> tmp;
Chris@49 35 spglue_times::apply_noalias(tmp, pa, pb);
Chris@49 36
Chris@49 37 out.steal_mem(tmp);
Chris@49 38 }
Chris@49 39 }
Chris@49 40 //! Sparse * sparse product c = pa * pb; 'c' is overwritten and must not alias an operand.
Chris@49 41 //! Pass 1 (SYMBMM) counts candidate entries per result column to bound the storage;
Chris@49 42 //! pass 2 (NUMBMM) computes the values, skips sums equal to zero and resizes 'c' to fit.
Chris@49 43 template<typename eT, typename T1, typename T2>
Chris@49 44 arma_hot
Chris@49 45 inline
Chris@49 46 void
Chris@49 47 spglue_times::apply_noalias(SpMat<eT>& c, const SpProxy<T1>& pa, const SpProxy<T2>& pb)
Chris@49 48 {
Chris@49 49 arma_extra_debug_sigprint();
Chris@49 50
Chris@49 51 const uword x_n_rows = pa.get_n_rows();
Chris@49 52 const uword x_n_cols = pa.get_n_cols();
Chris@49 53 const uword y_n_rows = pb.get_n_rows();
Chris@49 54 const uword y_n_cols = pb.get_n_cols();
Chris@49 55
Chris@49 56 arma_debug_assert_mul_size(x_n_rows, x_n_cols, y_n_rows, y_n_cols, "matrix multiplication");
Chris@49 57
Chris@49 58 // First we must determine the structure of the new matrix (column pointers).
Chris@49 59 // This follows the algorithm described in 'Sparse Matrix Multiplication
Chris@49 60 // Package (SMMP)' (R.E. Bank and C.C. Douglas, 2001). Their description of
Chris@49 61 // "SYMBMM" does not include anything about memory allocation. In addition it
Chris@49 62 // does not consider that there may be elements for which space is allocated
Chris@49 63 // but which evaluate to zero anyway. So we have to modify the algorithm
Chris@49 64 // to work that way. For the "SYMBMM" implementation we will not determine
Chris@49 65 // the row indices but instead just the column pointers.
Chris@49 66
Chris@49 67 //SpMat<typename T1::elem_type> c(x_n_rows, y_n_cols); // Initializes col_ptrs to 0.
Chris@49 68 c.zeros(x_n_rows, y_n_cols);
Chris@49 69
Chris@49 70 //if( (pa.get_n_elem() == 0) || (pb.get_n_elem() == 0) )
Chris@49 71 if( (pa.get_n_nonzero() == 0) || (pb.get_n_nonzero() == 0) )
Chris@49 72 {
Chris@49 73 return;
Chris@49 74 }
Chris@49 75 // index[] threads a singly linked list through the rows used by the current result column:
Chris@49 76 // index[r] == x_n_rows means row r is unused; otherwise index[r] is the next used row (x_n_rows + 1 ends the list).
Chris@49 77 podarray<uword> index(x_n_rows);
Chris@49 78 index.fill(x_n_rows); // Mark every row as unused.
Chris@49 79
Chris@49 80 typename SpProxy<T2>::const_iterator_type y_it = pb.begin();
Chris@49 81 typename SpProxy<T2>::const_iterator_type y_end = pb.end();
Chris@49 82
Chris@49 83 // SYMBMM: calculate column pointers for resultant matrix to obtain a good
Chris@49 84 // upper bound on the number of nonzero elements.
Chris@49 85 uword cur_col_length = 0;
Chris@49 86 uword last_ind = x_n_rows + 1; // Head of the used-row list; x_n_rows + 1 means "empty".
Chris@49 87 do
Chris@49 88 {
Chris@49 89 const uword y_it_row = y_it.row();
Chris@49 90
Chris@49 91 // Look through the column that this point (*y_it) could affect.
Chris@49 92 typename SpProxy<T1>::const_iterator_type x_it = pa.begin_col(y_it_row);
Chris@49 93
Chris@49 94 while(x_it.col() == y_it_row)
Chris@49 95 {
Chris@49 96 // A point at x(i, j) and y(j, k) implies a point at c(i, k).
Chris@49 97 if(index[x_it.row()] == x_n_rows)
Chris@49 98 {
Chris@49 99 index[x_it.row()] = last_ind;
Chris@49 100 last_ind = x_it.row();
Chris@49 101 ++cur_col_length;
Chris@49 102 }
Chris@49 103
Chris@49 104 ++x_it;
Chris@49 105 }
Chris@49 106
Chris@49 107 const uword old_col = y_it.col();
Chris@49 108 ++y_it;
Chris@49 109
Chris@49 110 // See if column incremented.
Chris@49 111 if(old_col != y_it.col())
Chris@49 112 {
Chris@49 113 // Set column pointer (this is not a cumulative count; that is done later).
Chris@49 114 access::rw(c.col_ptrs[old_col + 1]) = cur_col_length;
Chris@49 115 cur_col_length = 0;
Chris@49 116
Chris@49 117 // Reset the marker of every row used in this column back to "unused" (x_n_rows). Use last_ind for traversal.
Chris@49 118 while(last_ind != x_n_rows + 1)
Chris@49 119 {
Chris@49 120 const uword tmp = index[last_ind];
Chris@49 121 index[last_ind] = x_n_rows;
Chris@49 122 last_ind = tmp;
Chris@49 123 }
Chris@49 124 }
Chris@49 125 }
Chris@49 126 while(y_it != y_end);
Chris@49 127
Chris@49 128 // Accumulate column pointers.
Chris@49 129 for(uword i = 0; i < c.n_cols; ++i)
Chris@49 130 {
Chris@49 131 access::rw(c.col_ptrs[i + 1]) += c.col_ptrs[i];
Chris@49 132 }
Chris@49 133
Chris@49 134 // Now that we know a decent bound on the number of nonzero elements, allocate
Chris@49 135 // the memory and fill it.
Chris@49 136 c.mem_resize(c.col_ptrs[c.n_cols]);
Chris@49 137
Chris@49 138 // Now the implementation of the NUMBMM algorithm.
Chris@49 139 uword cur_pos = 0; // Current position in c matrix.
Chris@49 140 podarray<eT> sums(x_n_rows); // Partial sums.
Chris@49 141 sums.zeros();
Chris@49 142
Chris@49 143 // setting the size of 'sorted_indices' to x_n_rows is a better-than-nothing guess;
Chris@49 144 // the correct minimum size is determined later
Chris@49 145 podarray<uword> sorted_indices(x_n_rows);
Chris@49 146
Chris@49 147 // last_ind is expected to be back at the list terminator (x_n_rows + 1) after the SYMBMM pass.
Chris@49 148 // We will loop through all columns as necessary.
Chris@49 149 uword cur_col = 0;
Chris@49 150 while(cur_col < c.n_cols)
Chris@49 151 {
Chris@49 152 // Skip to next column with elements in it.
Chris@49 153 while((cur_col < c.n_cols) && (c.col_ptrs[cur_col] == c.col_ptrs[cur_col + 1]))
Chris@49 154 {
Chris@49 155 // Update current column pointer to actual number of nonzero elements up
Chris@49 156 // to this point.
Chris@49 157 access::rw(c.col_ptrs[cur_col]) = cur_pos;
Chris@49 158 ++cur_col;
Chris@49 159 }
Chris@49 160
Chris@49 161 if(cur_col == c.n_cols)
Chris@49 162 {
Chris@49 163 break;
Chris@49 164 }
Chris@49 165
Chris@49 166 // Update current column pointer.
Chris@49 167 access::rw(c.col_ptrs[cur_col]) = cur_pos;
Chris@49 168
Chris@49 169 // Check all elements in this column.
Chris@49 170 typename SpProxy<T2>::const_iterator_type y_col_it = pb.begin_col(cur_col);
Chris@49 171
Chris@49 172 while(y_col_it.col() == cur_col)
Chris@49 173 {
Chris@49 174 // Check all elements in the column of the other matrix corresponding to
Chris@49 175 // the row of this column.
Chris@49 176 typename SpProxy<T1>::const_iterator_type x_col_it = pa.begin_col(y_col_it.row());
Chris@49 177
Chris@49 178 const eT y_value = (*y_col_it);
Chris@49 179
Chris@49 180 while(x_col_it.col() == y_col_it.row())
Chris@49 181 {
Chris@49 182 // A point at x(i, j) and y(j, k) implies a point at c(i, k).
Chris@49 183 // Add to partial sum.
Chris@49 184 const eT x_value = (*x_col_it);
Chris@49 185 sums[x_col_it.row()] += (x_value * y_value);
Chris@49 186
Chris@49 187 // Add point if it hasn't already been marked.
Chris@49 188 if(index[x_col_it.row()] == x_n_rows)
Chris@49 189 {
Chris@49 190 index[x_col_it.row()] = last_ind;
Chris@49 191 last_ind = x_col_it.row();
Chris@49 192 }
Chris@49 193
Chris@49 194 ++x_col_it;
Chris@49 195 }
Chris@49 196
Chris@49 197 ++y_col_it;
Chris@49 198 }
Chris@49 199
Chris@49 200 // Now sort the indices that were used in this column.
Chris@49 201 //podarray<uword> sorted_indices(c.col_ptrs[cur_col + 1] - c.col_ptrs[cur_col]);
Chris@49 202 sorted_indices.set_min_size(c.col_ptrs[cur_col + 1] - c.col_ptrs[cur_col]);
Chris@49 203
Chris@49 204 // .set_min_size() can only enlarge the array to the specified size,
Chris@49 205 // hence if we request a smaller size than already allocated,
Chris@49 206 // no new memory allocation is done
Chris@49 207
Chris@49 208 // Walk the used-row list, collecting the rows to emit and resetting each marker to "unused".
Chris@49 209 uword cur_index = 0;
Chris@49 210 while(last_ind != x_n_rows + 1)
Chris@49 211 {
Chris@49 212 const uword tmp = last_ind;
Chris@49 213
Chris@49 214 // Check that it wasn't a "fake" nonzero element (a touched row whose partial sum is zero).
Chris@49 215 if(sums[tmp] != eT(0))
Chris@49 216 {
Chris@49 217 // Assign to next open position.
Chris@49 218 sorted_indices[cur_index] = tmp;
Chris@49 219 ++cur_index;
Chris@49 220 }
Chris@49 221
Chris@49 222 last_ind = index[tmp];
Chris@49 223 index[tmp] = x_n_rows;
Chris@49 224 }
Chris@49 225
Chris@49 226 // Now sort the indices.
Chris@49 227 if (cur_index != 0)
Chris@49 228 {
Chris@49 229 op_sort::direct_sort_ascending(sorted_indices.memptr(), cur_index);
Chris@49 230 // Emit the entries in ascending row order and clear each used partial sum for the next column.
Chris@49 231 for(uword k = 0; k < cur_index; ++k)
Chris@49 232 {
Chris@49 233 const uword row = sorted_indices[k];
Chris@49 234 access::rw(c.row_indices[cur_pos]) = row;
Chris@49 235 access::rw(c.values[cur_pos]) = sums[row];
Chris@49 236 sums[row] = eT(0);
Chris@49 237 ++cur_pos;
Chris@49 238 }
Chris@49 239 }
Chris@49 240
Chris@49 241 // Move to next column.
Chris@49 242 ++cur_col;
Chris@49 243 }
Chris@49 244
Chris@49 245 // Update last column pointer and resize to actual memory size.
Chris@49 246 access::rw(c.col_ptrs[c.n_cols]) = cur_pos;
Chris@49 247 c.mem_resize(cur_pos);
Chris@49 248 }
Chris@49 249
Chris@49 250
Chris@49 251
Chris@49 252 //
Chris@49 253 //
Chris@49 254 // spglue_times2: scalar*(A * B)
Chris@49 255 //! Evaluate X.aux * (X.A * X.B) and store the result in 'out'.
Chris@49 256 //! The product is formed with spglue_times::apply_noalias(), going through a
Chris@49 257 //! temporary when 'out' aliases an operand; 'out' is then multiplied by X.aux.
Chris@49 258 template<typename T1, typename T2>
Chris@49 259 inline
Chris@49 260 void
Chris@49 261 spglue_times2::apply(SpMat<typename T1::elem_type>& out, const SpGlue<T1,T2,spglue_times2>& X)
Chris@49 262 {
Chris@49 263 arma_extra_debug_sigprint();
Chris@49 264
Chris@49 265 typedef typename T1::elem_type eT;
Chris@49 266
Chris@49 267 const SpProxy<T1> pa(X.A);
Chris@49 268 const SpProxy<T2> pb(X.B);
Chris@49 269 // apply_noalias() resets its output before reading the operands, so 'out' must not be one of them.
Chris@49 270 const bool is_alias = pa.is_alias(out) || pb.is_alias(out);
Chris@49 271
Chris@49 272 if(is_alias == false)
Chris@49 273 {
Chris@49 274 spglue_times::apply_noalias(out, pa, pb);
Chris@49 275 }
Chris@49 276 else
Chris@49 277 {
Chris@49 278 SpMat<eT> tmp;
Chris@49 279 spglue_times::apply_noalias(tmp, pa, pb);
Chris@49 280
Chris@49 281 out.steal_mem(tmp);
Chris@49 282 }
Chris@49 283 // Apply the scalar factor to the finished product.
Chris@49 284 out *= X.aux;
Chris@49 285 }
Chris@49 286
Chris@49 287
Chris@49 288
Chris@49 289 //! @}