Chris@49: // Copyright (C) 2012 Ryan Curtin Chris@49: // Copyright (C) 2012 Conrad Sanderson Chris@49: // Chris@49: // This Source Code Form is subject to the terms of the Mozilla Public Chris@49: // License, v. 2.0. If a copy of the MPL was not distributed with this Chris@49: // file, You can obtain one at http://mozilla.org/MPL/2.0/. Chris@49: Chris@49: Chris@49: //! \addtogroup spglue_times Chris@49: //! @{ Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: void Chris@49: spglue_times::apply(SpMat& out, const SpGlue& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: Chris@49: const SpProxy pa(X.A); Chris@49: const SpProxy pb(X.B); Chris@49: Chris@49: const bool is_alias = pa.is_alias(out) || pb.is_alias(out); Chris@49: Chris@49: if(is_alias == false) Chris@49: { Chris@49: spglue_times::apply_noalias(out, pa, pb); Chris@49: } Chris@49: else Chris@49: { Chris@49: SpMat tmp; Chris@49: spglue_times::apply_noalias(tmp, pa, pb); Chris@49: Chris@49: out.steal_mem(tmp); Chris@49: } Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: arma_hot Chris@49: inline Chris@49: void Chris@49: spglue_times::apply_noalias(SpMat& c, const SpProxy& pa, const SpProxy& pb) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: const uword x_n_rows = pa.get_n_rows(); Chris@49: const uword x_n_cols = pa.get_n_cols(); Chris@49: const uword y_n_rows = pb.get_n_rows(); Chris@49: const uword y_n_cols = pb.get_n_cols(); Chris@49: Chris@49: arma_debug_assert_mul_size(x_n_rows, x_n_cols, y_n_rows, y_n_cols, "matrix multiplication"); Chris@49: Chris@49: // First we must determine the structure of the new matrix (column pointers). Chris@49: // This follows the algorithm described in 'Sparse Matrix Multiplication Chris@49: // Package (SMMP)' (R.E. Bank and C.C. Douglas, 2001). Their description of Chris@49: // "SYMBMM" does not include anything about memory allocation. In addition it Chris@49: // does not consider that there may be elements which space may be allocated Chris@49: // for but which evaluate to zero anyway. So we have to modify the algorithm Chris@49: // to work that way. For the "SYMBMM" implementation we will not determine Chris@49: // the row indices but instead just the column pointers. Chris@49: Chris@49: //SpMat c(x_n_rows, y_n_cols); // Initializes col_ptrs to 0. Chris@49: c.zeros(x_n_rows, y_n_cols); Chris@49: Chris@49: //if( (pa.get_n_elem() == 0) || (pb.get_n_elem() == 0) ) Chris@49: if( (pa.get_n_nonzero() == 0) || (pb.get_n_nonzero() == 0) ) Chris@49: { Chris@49: return; Chris@49: } Chris@49: Chris@49: // Auxiliary storage which denotes when items have been found. Chris@49: podarray index(x_n_rows); Chris@49: index.fill(x_n_rows); // Fill with invalid links. Chris@49: Chris@49: typename SpProxy::const_iterator_type y_it = pb.begin(); Chris@49: typename SpProxy::const_iterator_type y_end = pb.end(); Chris@49: Chris@49: // SYMBMM: calculate column pointers for resultant matrix to obtain a good Chris@49: // upper bound on the number of nonzero elements. Chris@49: uword cur_col_length = 0; Chris@49: uword last_ind = x_n_rows + 1; Chris@49: do Chris@49: { Chris@49: const uword y_it_row = y_it.row(); Chris@49: Chris@49: // Look through the column that this point (*y_it) could affect. Chris@49: typename SpProxy::const_iterator_type x_it = pa.begin_col(y_it_row); Chris@49: Chris@49: while(x_it.col() == y_it_row) Chris@49: { Chris@49: // A point at x(i, j) and y(j, k) implies a point at c(i, k). Chris@49: if(index[x_it.row()] == x_n_rows) Chris@49: { Chris@49: index[x_it.row()] = last_ind; Chris@49: last_ind = x_it.row(); Chris@49: ++cur_col_length; Chris@49: } Chris@49: Chris@49: ++x_it; Chris@49: } Chris@49: Chris@49: const uword old_col = y_it.col(); Chris@49: ++y_it; Chris@49: Chris@49: // See if column incremented. Chris@49: if(old_col != y_it.col()) Chris@49: { Chris@49: // Set column pointer (this is not a cumulative count; that is done later). Chris@49: access::rw(c.col_ptrs[old_col + 1]) = cur_col_length; Chris@49: cur_col_length = 0; Chris@49: Chris@49: // Return index markers to zero. Use last_ind for traversal. Chris@49: while(last_ind != x_n_rows + 1) Chris@49: { Chris@49: const uword tmp = index[last_ind]; Chris@49: index[last_ind] = x_n_rows; Chris@49: last_ind = tmp; Chris@49: } Chris@49: } Chris@49: } Chris@49: while(y_it != y_end); Chris@49: Chris@49: // Accumulate column pointers. Chris@49: for(uword i = 0; i < c.n_cols; ++i) Chris@49: { Chris@49: access::rw(c.col_ptrs[i + 1]) += c.col_ptrs[i]; Chris@49: } Chris@49: Chris@49: // Now that we know a decent bound on the number of nonzero elements, allocate Chris@49: // the memory and fill it. Chris@49: c.mem_resize(c.col_ptrs[c.n_cols]); Chris@49: Chris@49: // Now the implementation of the NUMBMM algorithm. Chris@49: uword cur_pos = 0; // Current position in c matrix. Chris@49: podarray sums(x_n_rows); // Partial sums. Chris@49: sums.zeros(); Chris@49: Chris@49: // setting the size of 'sorted_indices' to x_n_rows is a better-than-nothing guess; Chris@49: // the correct minimum size is determined later Chris@49: podarray sorted_indices(x_n_rows); Chris@49: Chris@49: // last_ind is already set to x_n_rows, and cur_col_length is already set to 0. Chris@49: // We will loop through all columns as necessary. Chris@49: uword cur_col = 0; Chris@49: while(cur_col < c.n_cols) Chris@49: { Chris@49: // Skip to next column with elements in it. Chris@49: while((cur_col < c.n_cols) && (c.col_ptrs[cur_col] == c.col_ptrs[cur_col + 1])) Chris@49: { Chris@49: // Update current column pointer to actual number of nonzero elements up Chris@49: // to this point. Chris@49: access::rw(c.col_ptrs[cur_col]) = cur_pos; Chris@49: ++cur_col; Chris@49: } Chris@49: Chris@49: if(cur_col == c.n_cols) Chris@49: { Chris@49: break; Chris@49: } Chris@49: Chris@49: // Update current column pointer. Chris@49: access::rw(c.col_ptrs[cur_col]) = cur_pos; Chris@49: Chris@49: // Check all elements in this column. Chris@49: typename SpProxy::const_iterator_type y_col_it = pb.begin_col(cur_col); Chris@49: Chris@49: while(y_col_it.col() == cur_col) Chris@49: { Chris@49: // Check all elements in the column of the other matrix corresponding to Chris@49: // the row of this column. Chris@49: typename SpProxy::const_iterator_type x_col_it = pa.begin_col(y_col_it.row()); Chris@49: Chris@49: const eT y_value = (*y_col_it); Chris@49: Chris@49: while(x_col_it.col() == y_col_it.row()) Chris@49: { Chris@49: // A point at x(i, j) and y(j, k) implies a point at c(i, k). Chris@49: // Add to partial sum. Chris@49: const eT x_value = (*x_col_it); Chris@49: sums[x_col_it.row()] += (x_value * y_value); Chris@49: Chris@49: // Add point if it hasn't already been marked. Chris@49: if(index[x_col_it.row()] == x_n_rows) Chris@49: { Chris@49: index[x_col_it.row()] = last_ind; Chris@49: last_ind = x_col_it.row(); Chris@49: } Chris@49: Chris@49: ++x_col_it; Chris@49: } Chris@49: Chris@49: ++y_col_it; Chris@49: } Chris@49: Chris@49: // Now sort the indices that were used in this column. Chris@49: //podarray sorted_indices(c.col_ptrs[cur_col + 1] - c.col_ptrs[cur_col]); Chris@49: sorted_indices.set_min_size(c.col_ptrs[cur_col + 1] - c.col_ptrs[cur_col]); Chris@49: Chris@49: // .set_min_size() can only enlarge the array to the specified size, Chris@49: // hence if we request a smaller size than already allocated, Chris@49: // no new memory allocation is done Chris@49: Chris@49: Chris@49: uword cur_index = 0; Chris@49: while(last_ind != x_n_rows + 1) Chris@49: { Chris@49: const uword tmp = last_ind; Chris@49: Chris@49: // Check that it wasn't a "fake" nonzero element. Chris@49: if(sums[tmp] != eT(0)) Chris@49: { Chris@49: // Assign to next open position. Chris@49: sorted_indices[cur_index] = tmp; Chris@49: ++cur_index; Chris@49: } Chris@49: Chris@49: last_ind = index[tmp]; Chris@49: index[tmp] = x_n_rows; Chris@49: } Chris@49: Chris@49: // Now sort the indices. Chris@49: if (cur_index != 0) Chris@49: { Chris@49: op_sort::direct_sort_ascending(sorted_indices.memptr(), cur_index); Chris@49: Chris@49: for(uword k = 0; k < cur_index; ++k) Chris@49: { Chris@49: const uword row = sorted_indices[k]; Chris@49: access::rw(c.row_indices[cur_pos]) = row; Chris@49: access::rw(c.values[cur_pos]) = sums[row]; Chris@49: sums[row] = eT(0); Chris@49: ++cur_pos; Chris@49: } Chris@49: } Chris@49: Chris@49: // Move to next column. Chris@49: ++cur_col; Chris@49: } Chris@49: Chris@49: // Update last column pointer and resize to actual memory size. Chris@49: access::rw(c.col_ptrs[c.n_cols]) = cur_pos; Chris@49: c.mem_resize(cur_pos); Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: // Chris@49: // Chris@49: // spglue_times2: scalar*(A * B) Chris@49: Chris@49: Chris@49: Chris@49: template Chris@49: inline Chris@49: void Chris@49: spglue_times2::apply(SpMat& out, const SpGlue& X) Chris@49: { Chris@49: arma_extra_debug_sigprint(); Chris@49: Chris@49: typedef typename T1::elem_type eT; Chris@49: Chris@49: const SpProxy pa(X.A); Chris@49: const SpProxy pb(X.B); Chris@49: Chris@49: const bool is_alias = pa.is_alias(out) || pb.is_alias(out); Chris@49: Chris@49: if(is_alias == false) Chris@49: { Chris@49: spglue_times::apply_noalias(out, pa, pb); Chris@49: } Chris@49: else Chris@49: { Chris@49: SpMat tmp; Chris@49: spglue_times::apply_noalias(tmp, pa, pb); Chris@49: Chris@49: out.steal_mem(tmp); Chris@49: } Chris@49: Chris@49: out *= X.aux; Chris@49: } Chris@49: Chris@49: Chris@49: Chris@49: //! @}