Chris@49
|
1 // Copyright (C) 2012 Ryan Curtin
|
Chris@49
|
2 // Copyright (C) 2012 Conrad Sanderson
|
Chris@49
|
3 //
|
Chris@49
|
4 // This Source Code Form is subject to the terms of the Mozilla Public
|
Chris@49
|
5 // License, v. 2.0. If a copy of the MPL was not distributed with this
|
Chris@49
|
6 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
Chris@49
|
7
|
Chris@49
|
8
|
Chris@49
|
9 //! \addtogroup spglue_times
|
Chris@49
|
10 //! @{
|
Chris@49
|
11
|
Chris@49
|
12
|
Chris@49
|
13
|
Chris@49
|
14 template<typename T1, typename T2>
|
Chris@49
|
15 inline
|
Chris@49
|
16 void
|
Chris@49
|
17 spglue_times::apply(SpMat<typename T1::elem_type>& out, const SpGlue<T1,T2,spglue_times>& X)
|
Chris@49
|
18 {
|
Chris@49
|
19 arma_extra_debug_sigprint();
|
Chris@49
|
20
|
Chris@49
|
21 typedef typename T1::elem_type eT;
|
Chris@49
|
22
|
Chris@49
|
23 const SpProxy<T1> pa(X.A);
|
Chris@49
|
24 const SpProxy<T2> pb(X.B);
|
Chris@49
|
25
|
Chris@49
|
26 const bool is_alias = pa.is_alias(out) || pb.is_alias(out);
|
Chris@49
|
27
|
Chris@49
|
28 if(is_alias == false)
|
Chris@49
|
29 {
|
Chris@49
|
30 spglue_times::apply_noalias(out, pa, pb);
|
Chris@49
|
31 }
|
Chris@49
|
32 else
|
Chris@49
|
33 {
|
Chris@49
|
34 SpMat<eT> tmp;
|
Chris@49
|
35 spglue_times::apply_noalias(tmp, pa, pb);
|
Chris@49
|
36
|
Chris@49
|
37 out.steal_mem(tmp);
|
Chris@49
|
38 }
|
Chris@49
|
39 }
|
Chris@49
|
40
|
Chris@49
|
41
|
Chris@49
|
42
|
Chris@49
|
43 template<typename eT, typename T1, typename T2>
|
Chris@49
|
44 arma_hot
|
Chris@49
|
45 inline
|
Chris@49
|
46 void
|
Chris@49
|
47 spglue_times::apply_noalias(SpMat<eT>& c, const SpProxy<T1>& pa, const SpProxy<T2>& pb)
|
Chris@49
|
48 {
|
Chris@49
|
49 arma_extra_debug_sigprint();
|
Chris@49
|
50
|
Chris@49
|
51 const uword x_n_rows = pa.get_n_rows();
|
Chris@49
|
52 const uword x_n_cols = pa.get_n_cols();
|
Chris@49
|
53 const uword y_n_rows = pb.get_n_rows();
|
Chris@49
|
54 const uword y_n_cols = pb.get_n_cols();
|
Chris@49
|
55
|
Chris@49
|
56 arma_debug_assert_mul_size(x_n_rows, x_n_cols, y_n_rows, y_n_cols, "matrix multiplication");
|
Chris@49
|
57
|
Chris@49
|
58 // First we must determine the structure of the new matrix (column pointers).
|
Chris@49
|
59 // This follows the algorithm described in 'Sparse Matrix Multiplication
|
Chris@49
|
60 // Package (SMMP)' (R.E. Bank and C.C. Douglas, 2001). Their description of
|
Chris@49
|
61 // "SYMBMM" does not include anything about memory allocation. In addition it
|
Chris@49
|
62 // does not consider that there may be elements which space may be allocated
|
Chris@49
|
63 // for but which evaluate to zero anyway. So we have to modify the algorithm
|
Chris@49
|
64 // to work that way. For the "SYMBMM" implementation we will not determine
|
Chris@49
|
65 // the row indices but instead just the column pointers.
|
Chris@49
|
66
|
Chris@49
|
67 //SpMat<typename T1::elem_type> c(x_n_rows, y_n_cols); // Initializes col_ptrs to 0.
|
Chris@49
|
68 c.zeros(x_n_rows, y_n_cols);
|
Chris@49
|
69
|
Chris@49
|
70 //if( (pa.get_n_elem() == 0) || (pb.get_n_elem() == 0) )
|
Chris@49
|
71 if( (pa.get_n_nonzero() == 0) || (pb.get_n_nonzero() == 0) )
|
Chris@49
|
72 {
|
Chris@49
|
73 return;
|
Chris@49
|
74 }
|
Chris@49
|
75
|
Chris@49
|
76 // Auxiliary storage which denotes when items have been found.
|
Chris@49
|
77 podarray<uword> index(x_n_rows);
|
Chris@49
|
78 index.fill(x_n_rows); // Fill with invalid links.
|
Chris@49
|
79
|
Chris@49
|
80 typename SpProxy<T2>::const_iterator_type y_it = pb.begin();
|
Chris@49
|
81 typename SpProxy<T2>::const_iterator_type y_end = pb.end();
|
Chris@49
|
82
|
Chris@49
|
83 // SYMBMM: calculate column pointers for resultant matrix to obtain a good
|
Chris@49
|
84 // upper bound on the number of nonzero elements.
|
Chris@49
|
85 uword cur_col_length = 0;
|
Chris@49
|
86 uword last_ind = x_n_rows + 1;
|
Chris@49
|
87 do
|
Chris@49
|
88 {
|
Chris@49
|
89 const uword y_it_row = y_it.row();
|
Chris@49
|
90
|
Chris@49
|
91 // Look through the column that this point (*y_it) could affect.
|
Chris@49
|
92 typename SpProxy<T1>::const_iterator_type x_it = pa.begin_col(y_it_row);
|
Chris@49
|
93
|
Chris@49
|
94 while(x_it.col() == y_it_row)
|
Chris@49
|
95 {
|
Chris@49
|
96 // A point at x(i, j) and y(j, k) implies a point at c(i, k).
|
Chris@49
|
97 if(index[x_it.row()] == x_n_rows)
|
Chris@49
|
98 {
|
Chris@49
|
99 index[x_it.row()] = last_ind;
|
Chris@49
|
100 last_ind = x_it.row();
|
Chris@49
|
101 ++cur_col_length;
|
Chris@49
|
102 }
|
Chris@49
|
103
|
Chris@49
|
104 ++x_it;
|
Chris@49
|
105 }
|
Chris@49
|
106
|
Chris@49
|
107 const uword old_col = y_it.col();
|
Chris@49
|
108 ++y_it;
|
Chris@49
|
109
|
Chris@49
|
110 // See if column incremented.
|
Chris@49
|
111 if(old_col != y_it.col())
|
Chris@49
|
112 {
|
Chris@49
|
113 // Set column pointer (this is not a cumulative count; that is done later).
|
Chris@49
|
114 access::rw(c.col_ptrs[old_col + 1]) = cur_col_length;
|
Chris@49
|
115 cur_col_length = 0;
|
Chris@49
|
116
|
Chris@49
|
117 // Return index markers to zero. Use last_ind for traversal.
|
Chris@49
|
118 while(last_ind != x_n_rows + 1)
|
Chris@49
|
119 {
|
Chris@49
|
120 const uword tmp = index[last_ind];
|
Chris@49
|
121 index[last_ind] = x_n_rows;
|
Chris@49
|
122 last_ind = tmp;
|
Chris@49
|
123 }
|
Chris@49
|
124 }
|
Chris@49
|
125 }
|
Chris@49
|
126 while(y_it != y_end);
|
Chris@49
|
127
|
Chris@49
|
128 // Accumulate column pointers.
|
Chris@49
|
129 for(uword i = 0; i < c.n_cols; ++i)
|
Chris@49
|
130 {
|
Chris@49
|
131 access::rw(c.col_ptrs[i + 1]) += c.col_ptrs[i];
|
Chris@49
|
132 }
|
Chris@49
|
133
|
Chris@49
|
134 // Now that we know a decent bound on the number of nonzero elements, allocate
|
Chris@49
|
135 // the memory and fill it.
|
Chris@49
|
136 c.mem_resize(c.col_ptrs[c.n_cols]);
|
Chris@49
|
137
|
Chris@49
|
138 // Now the implementation of the NUMBMM algorithm.
|
Chris@49
|
139 uword cur_pos = 0; // Current position in c matrix.
|
Chris@49
|
140 podarray<eT> sums(x_n_rows); // Partial sums.
|
Chris@49
|
141 sums.zeros();
|
Chris@49
|
142
|
Chris@49
|
143 // setting the size of 'sorted_indices' to x_n_rows is a better-than-nothing guess;
|
Chris@49
|
144 // the correct minimum size is determined later
|
Chris@49
|
145 podarray<uword> sorted_indices(x_n_rows);
|
Chris@49
|
146
|
Chris@49
|
147 // last_ind is already set to x_n_rows, and cur_col_length is already set to 0.
|
Chris@49
|
148 // We will loop through all columns as necessary.
|
Chris@49
|
149 uword cur_col = 0;
|
Chris@49
|
150 while(cur_col < c.n_cols)
|
Chris@49
|
151 {
|
Chris@49
|
152 // Skip to next column with elements in it.
|
Chris@49
|
153 while((cur_col < c.n_cols) && (c.col_ptrs[cur_col] == c.col_ptrs[cur_col + 1]))
|
Chris@49
|
154 {
|
Chris@49
|
155 // Update current column pointer to actual number of nonzero elements up
|
Chris@49
|
156 // to this point.
|
Chris@49
|
157 access::rw(c.col_ptrs[cur_col]) = cur_pos;
|
Chris@49
|
158 ++cur_col;
|
Chris@49
|
159 }
|
Chris@49
|
160
|
Chris@49
|
161 if(cur_col == c.n_cols)
|
Chris@49
|
162 {
|
Chris@49
|
163 break;
|
Chris@49
|
164 }
|
Chris@49
|
165
|
Chris@49
|
166 // Update current column pointer.
|
Chris@49
|
167 access::rw(c.col_ptrs[cur_col]) = cur_pos;
|
Chris@49
|
168
|
Chris@49
|
169 // Check all elements in this column.
|
Chris@49
|
170 typename SpProxy<T2>::const_iterator_type y_col_it = pb.begin_col(cur_col);
|
Chris@49
|
171
|
Chris@49
|
172 while(y_col_it.col() == cur_col)
|
Chris@49
|
173 {
|
Chris@49
|
174 // Check all elements in the column of the other matrix corresponding to
|
Chris@49
|
175 // the row of this column.
|
Chris@49
|
176 typename SpProxy<T1>::const_iterator_type x_col_it = pa.begin_col(y_col_it.row());
|
Chris@49
|
177
|
Chris@49
|
178 const eT y_value = (*y_col_it);
|
Chris@49
|
179
|
Chris@49
|
180 while(x_col_it.col() == y_col_it.row())
|
Chris@49
|
181 {
|
Chris@49
|
182 // A point at x(i, j) and y(j, k) implies a point at c(i, k).
|
Chris@49
|
183 // Add to partial sum.
|
Chris@49
|
184 const eT x_value = (*x_col_it);
|
Chris@49
|
185 sums[x_col_it.row()] += (x_value * y_value);
|
Chris@49
|
186
|
Chris@49
|
187 // Add point if it hasn't already been marked.
|
Chris@49
|
188 if(index[x_col_it.row()] == x_n_rows)
|
Chris@49
|
189 {
|
Chris@49
|
190 index[x_col_it.row()] = last_ind;
|
Chris@49
|
191 last_ind = x_col_it.row();
|
Chris@49
|
192 }
|
Chris@49
|
193
|
Chris@49
|
194 ++x_col_it;
|
Chris@49
|
195 }
|
Chris@49
|
196
|
Chris@49
|
197 ++y_col_it;
|
Chris@49
|
198 }
|
Chris@49
|
199
|
Chris@49
|
200 // Now sort the indices that were used in this column.
|
Chris@49
|
201 //podarray<uword> sorted_indices(c.col_ptrs[cur_col + 1] - c.col_ptrs[cur_col]);
|
Chris@49
|
202 sorted_indices.set_min_size(c.col_ptrs[cur_col + 1] - c.col_ptrs[cur_col]);
|
Chris@49
|
203
|
Chris@49
|
204 // .set_min_size() can only enlarge the array to the specified size,
|
Chris@49
|
205 // hence if we request a smaller size than already allocated,
|
Chris@49
|
206 // no new memory allocation is done
|
Chris@49
|
207
|
Chris@49
|
208
|
Chris@49
|
209 uword cur_index = 0;
|
Chris@49
|
210 while(last_ind != x_n_rows + 1)
|
Chris@49
|
211 {
|
Chris@49
|
212 const uword tmp = last_ind;
|
Chris@49
|
213
|
Chris@49
|
214 // Check that it wasn't a "fake" nonzero element.
|
Chris@49
|
215 if(sums[tmp] != eT(0))
|
Chris@49
|
216 {
|
Chris@49
|
217 // Assign to next open position.
|
Chris@49
|
218 sorted_indices[cur_index] = tmp;
|
Chris@49
|
219 ++cur_index;
|
Chris@49
|
220 }
|
Chris@49
|
221
|
Chris@49
|
222 last_ind = index[tmp];
|
Chris@49
|
223 index[tmp] = x_n_rows;
|
Chris@49
|
224 }
|
Chris@49
|
225
|
Chris@49
|
226 // Now sort the indices.
|
Chris@49
|
227 if (cur_index != 0)
|
Chris@49
|
228 {
|
Chris@49
|
229 op_sort::direct_sort_ascending(sorted_indices.memptr(), cur_index);
|
Chris@49
|
230
|
Chris@49
|
231 for(uword k = 0; k < cur_index; ++k)
|
Chris@49
|
232 {
|
Chris@49
|
233 const uword row = sorted_indices[k];
|
Chris@49
|
234 access::rw(c.row_indices[cur_pos]) = row;
|
Chris@49
|
235 access::rw(c.values[cur_pos]) = sums[row];
|
Chris@49
|
236 sums[row] = eT(0);
|
Chris@49
|
237 ++cur_pos;
|
Chris@49
|
238 }
|
Chris@49
|
239 }
|
Chris@49
|
240
|
Chris@49
|
241 // Move to next column.
|
Chris@49
|
242 ++cur_col;
|
Chris@49
|
243 }
|
Chris@49
|
244
|
Chris@49
|
245 // Update last column pointer and resize to actual memory size.
|
Chris@49
|
246 access::rw(c.col_ptrs[c.n_cols]) = cur_pos;
|
Chris@49
|
247 c.mem_resize(cur_pos);
|
Chris@49
|
248 }
|
Chris@49
|
249
|
Chris@49
|
250
|
Chris@49
|
251
|
Chris@49
|
252 //
|
Chris@49
|
253 //
|
Chris@49
|
254 // spglue_times2: scalar*(A * B)
|
Chris@49
|
255
|
Chris@49
|
256
|
Chris@49
|
257
|
Chris@49
|
258 template<typename T1, typename T2>
|
Chris@49
|
259 inline
|
Chris@49
|
260 void
|
Chris@49
|
261 spglue_times2::apply(SpMat<typename T1::elem_type>& out, const SpGlue<T1,T2,spglue_times2>& X)
|
Chris@49
|
262 {
|
Chris@49
|
263 arma_extra_debug_sigprint();
|
Chris@49
|
264
|
Chris@49
|
265 typedef typename T1::elem_type eT;
|
Chris@49
|
266
|
Chris@49
|
267 const SpProxy<T1> pa(X.A);
|
Chris@49
|
268 const SpProxy<T2> pb(X.B);
|
Chris@49
|
269
|
Chris@49
|
270 const bool is_alias = pa.is_alias(out) || pb.is_alias(out);
|
Chris@49
|
271
|
Chris@49
|
272 if(is_alias == false)
|
Chris@49
|
273 {
|
Chris@49
|
274 spglue_times::apply_noalias(out, pa, pb);
|
Chris@49
|
275 }
|
Chris@49
|
276 else
|
Chris@49
|
277 {
|
Chris@49
|
278 SpMat<eT> tmp;
|
Chris@49
|
279 spglue_times::apply_noalias(tmp, pa, pb);
|
Chris@49
|
280
|
Chris@49
|
281 out.steal_mem(tmp);
|
Chris@49
|
282 }
|
Chris@49
|
283
|
Chris@49
|
284 out *= X.aux;
|
Chris@49
|
285 }
|
Chris@49
|
286
|
Chris@49
|
287
|
Chris@49
|
288
|
Chris@49
|
289 //! @}
|