Mercurial > hg > segmenter-vamp-plugin
comparison armadillo-3.900.4/include/armadillo_bits/SpMat_meat.hpp @ 49:1ec0e2823891
Switch to using subrepo copies of qm-dsp, nnls-chroma, vamp-plugin-sdk; update Armadillo version; assume build without external BLAS/LAPACK
author | Chris Cannam |
---|---|
date | Thu, 13 Jun 2013 10:25:24 +0100 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
48:69251e11a913 | 49:1ec0e2823891 |
---|---|
1 // Copyright (C) 2011-2013 Ryan Curtin | |
2 // Copyright (C) 2012-2013 Conrad Sanderson | |
3 // Copyright (C) 2011 Matthew Amidon | |
4 // | |
5 // This Source Code Form is subject to the terms of the Mozilla Public | |
6 // License, v. 2.0. If a copy of the MPL was not distributed with this | |
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/. | |
8 | |
9 //! \addtogroup SpMat | |
10 //! @{ | |
11 | |
12 /** | |
13 * Initialize a sparse matrix with size 0x0 (empty). | |
14 */ | |
15 template<typename eT> | |
16 inline | |
17 SpMat<eT>::SpMat() | |
18 : n_rows(0) | |
19 , n_cols(0) | |
20 , n_elem(0) | |
21 , n_nonzero(0) | |
22 , vec_state(0) | |
23 , values(memory::acquire_chunked<eT>(1)) | |
24 , row_indices(memory::acquire_chunked<uword>(1)) | |
25 , col_ptrs(memory::acquire<uword>(2)) | |
26 { | |
27 arma_extra_debug_sigprint_this(this); | |
28 | |
29 access::rw(values[0]) = 0; | |
30 access::rw(row_indices[0]) = 0; | |
31 | |
32 access::rw(col_ptrs[0]) = 0; // No elements. | |
33 access::rw(col_ptrs[1]) = std::numeric_limits<uword>::max(); | |
34 } | |
35 | |
36 | |
37 | |
38 /** | |
39 * Clean up the memory of a sparse matrix and destruct it. | |
40 */ | |
41 template<typename eT> | |
42 inline | |
43 SpMat<eT>::~SpMat() | |
44 { | |
45 arma_extra_debug_sigprint_this(this); | |
46 | |
47 // If necessary, release the memory. | |
48 if (values) | |
49 { | |
50 // values being non-NULL implies row_indices is non-NULL. | |
51 memory::release(access::rw(values)); | |
52 memory::release(access::rw(row_indices)); | |
53 } | |
54 | |
55 // Column pointers always must be deleted. | |
56 memory::release(access::rw(col_ptrs)); | |
57 } | |
58 | |
59 | |
60 | |
61 /** | |
62 * Constructor with size given. | |
63 */ | |
64 template<typename eT> | |
65 inline | |
66 SpMat<eT>::SpMat(const uword in_rows, const uword in_cols) | |
67 : n_rows(0) | |
68 , n_cols(0) | |
69 , n_elem(0) | |
70 , n_nonzero(0) | |
71 , vec_state(0) | |
72 , values(NULL) | |
73 , row_indices(NULL) | |
74 , col_ptrs(NULL) | |
75 { | |
76 arma_extra_debug_sigprint_this(this); | |
77 | |
78 init(in_rows, in_cols); | |
79 } | |
80 | |
81 | |
82 | |
83 /** | |
84 * Assemble from text. | |
85 */ | |
86 template<typename eT> | |
87 inline | |
88 SpMat<eT>::SpMat(const char* text) | |
89 : n_rows(0) | |
90 , n_cols(0) | |
91 , n_elem(0) | |
92 , n_nonzero(0) | |
93 , vec_state(0) | |
94 , values(NULL) | |
95 , row_indices(NULL) | |
96 , col_ptrs(NULL) | |
97 { | |
98 arma_extra_debug_sigprint_this(this); | |
99 | |
100 init(std::string(text)); | |
101 } | |
102 | |
103 | |
104 | |
105 template<typename eT> | |
106 inline | |
107 const SpMat<eT>& | |
108 SpMat<eT>::operator=(const char* text) | |
109 { | |
110 arma_extra_debug_sigprint(); | |
111 | |
112 init(std::string(text)); | |
113 } | |
114 | |
115 | |
116 | |
117 template<typename eT> | |
118 inline | |
119 SpMat<eT>::SpMat(const std::string& text) | |
120 : n_rows(0) | |
121 , n_cols(0) | |
122 , n_elem(0) | |
123 , n_nonzero(0) | |
124 , vec_state(0) | |
125 , values(NULL) | |
126 , row_indices(NULL) | |
127 , col_ptrs(NULL) | |
128 { | |
129 arma_extra_debug_sigprint(); | |
130 | |
131 init(text); | |
132 } | |
133 | |
134 | |
135 | |
136 template<typename eT> | |
137 inline | |
138 const SpMat<eT>& | |
139 SpMat<eT>::operator=(const std::string& text) | |
140 { | |
141 arma_extra_debug_sigprint(); | |
142 | |
143 init(text); | |
144 } | |
145 | |
146 | |
147 | |
148 template<typename eT> | |
149 inline | |
150 SpMat<eT>::SpMat(const SpMat<eT>& x) | |
151 : n_rows(0) | |
152 , n_cols(0) | |
153 , n_elem(0) | |
154 , n_nonzero(0) | |
155 , vec_state(0) | |
156 , values(NULL) | |
157 , row_indices(NULL) | |
158 , col_ptrs(NULL) | |
159 { | |
160 arma_extra_debug_sigprint_this(this); | |
161 | |
162 init(x); | |
163 } | |
164 | |
165 | |
166 | |
167 //! Insert a large number of values at once. | |
168 //! locations.row[0] should be row indices, locations.row[1] should be column indices, | |
169 //! and values should be the corresponding values. | |
170 //! If sort_locations is false, then it is assumed that the locations and values | |
171 //! are already sorted in column-major ordering. | |
172 template<typename eT> | |
173 template<typename T1, typename T2> | |
174 inline | |
175 SpMat<eT>::SpMat(const Base<uword,T1>& locations_expr, const Base<eT,T2>& vals_expr, const bool sort_locations) | |
176 : n_rows(0) | |
177 , n_cols(0) | |
178 , n_elem(0) | |
179 , n_nonzero(0) | |
180 , vec_state(0) | |
181 , values(NULL) | |
182 , row_indices(NULL) | |
183 , col_ptrs(NULL) | |
184 { | |
185 arma_extra_debug_sigprint_this(this); | |
186 | |
187 const unwrap<T1> locs_tmp( locations_expr.get_ref() ); | |
188 const Mat<uword>& locs = locs_tmp.M; | |
189 | |
190 const unwrap<T2> vals_tmp( vals_expr.get_ref() ); | |
191 const Mat<eT>& vals = vals_tmp.M; | |
192 | |
193 arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object is not a vector" ); | |
194 | |
195 arma_debug_check((locs.n_cols != vals.n_elem), "SpMat::SpMat(): number of locations is different than number of values"); | |
196 | |
197 // If there are no elements in the list, max() will fail. | |
198 if (locs.n_cols == 0) | |
199 { | |
200 init(0, 0); | |
201 return; | |
202 } | |
203 | |
204 arma_debug_check((locs.n_rows != 2), "SpMat::SpMat(): locations matrix must have two rows"); | |
205 | |
206 // Automatically determine size (and check if it's sorted). | |
207 uvec bounds = arma::max(locs, 1); | |
208 init(bounds[0] + 1, bounds[1] + 1); | |
209 | |
210 // Resize to correct number of elements. | |
211 mem_resize(vals.n_elem); | |
212 | |
213 // Reset column pointers to zero. | |
214 arrayops::inplace_set(access::rwp(col_ptrs), uword(0), n_cols + 1); | |
215 | |
216 bool actually_sorted = true; | |
217 if(sort_locations == true) | |
218 { | |
219 // sort_index() uses std::sort() which may use quicksort... so we better | |
220 // make sure it's not already sorted before taking an O(N^2) sort penalty. | |
221 for (uword i = 1; i < locs.n_cols; ++i) | |
222 { | |
223 if ((locs.at(1, i) < locs.at(1, i - 1)) || (locs.at(1, i) == locs.at(1, i - 1) && locs.at(0, i) <= locs.at(0, i - 1))) | |
224 { | |
225 actually_sorted = false; | |
226 break; | |
227 } | |
228 } | |
229 | |
230 if(actually_sorted == false) | |
231 { | |
232 // This may not be the fastest possible implementation but it maximizes code reuse. | |
233 Col<uword> abslocs(locs.n_cols); | |
234 | |
235 for (uword i = 0; i < locs.n_cols; ++i) | |
236 { | |
237 abslocs[i] = locs.at(1, i) * n_rows + locs.at(0, i); | |
238 } | |
239 | |
240 // Now we will sort with sort_index(). | |
241 uvec sorted_indices = sort_index(abslocs); // Ascending sort. | |
242 | |
243 // Now we add the elements in this sorted order. | |
244 for (uword i = 0; i < sorted_indices.n_elem; ++i) | |
245 { | |
246 arma_debug_check((locs.at(0, sorted_indices[i]) >= n_rows), "SpMat::SpMat(): invalid row index"); | |
247 arma_debug_check((locs.at(1, sorted_indices[i]) >= n_cols), "SpMat::SpMat(): invalid column index"); | |
248 | |
249 access::rw(values[i]) = vals[sorted_indices[i]]; | |
250 access::rw(row_indices[i]) = locs.at(0, sorted_indices[i]); | |
251 | |
252 access::rw(col_ptrs[locs.at(1, sorted_indices[i]) + 1])++; | |
253 } | |
254 } | |
255 } | |
256 | |
257 if( (sort_locations == false) || (actually_sorted == true) ) | |
258 { | |
259 // Now set the values and row indices correctly. | |
260 // Increment the column pointers in each column (so they are column "counts"). | |
261 for (uword i = 0; i < vals.n_elem; ++i) | |
262 { | |
263 arma_debug_check((locs.at(0, i) >= n_rows), "SpMat::SpMat(): invalid row index"); | |
264 arma_debug_check((locs.at(1, i) >= n_cols), "SpMat::SpMat(): invalid column index"); | |
265 | |
266 // Check ordering in debug mode. | |
267 if(i > 0) | |
268 { | |
269 arma_debug_check | |
270 ( | |
271 ( (locs.at(1, i) < locs.at(1, i - 1)) || (locs.at(1, i) == locs.at(1, i - 1) && locs.at(0, i) < locs.at(0, i - 1)) ), | |
272 "SpMat::SpMat(): out of order points; either pass sort_locations = true, or sort points in column-major ordering" | |
273 ); | |
274 arma_debug_check((locs.at(1, i) == locs.at(1, i - 1) && locs.at(0, i) == locs.at(0, i - 1)), "SpMat::SpMat(): two identical point locations in list"); | |
275 } | |
276 | |
277 access::rw(values[i]) = vals[i]; | |
278 access::rw(row_indices[i]) = locs.at(0, i); | |
279 | |
280 access::rw(col_ptrs[locs.at(1, i) + 1])++; | |
281 } | |
282 } | |
283 | |
284 // Now fix the column pointers. | |
285 for (uword i = 0; i <= n_cols; ++i) | |
286 { | |
287 access::rw(col_ptrs[i + 1]) += col_ptrs[i]; | |
288 } | |
289 } | |
290 | |
291 | |
292 | |
293 //! Insert a large number of values at once. | |
294 //! locations.row[0] should be row indices, locations.row[1] should be column indices, | |
295 //! and values should be the corresponding values. | |
296 //! If sort_locations is false, then it is assumed that the locations and values | |
297 //! are already sorted in column-major ordering. | |
298 //! In this constructor the size is explicitly given. | |
299 template<typename eT> | |
300 template<typename T1, typename T2> | |
301 inline | |
302 SpMat<eT>::SpMat(const Base<uword,T1>& locations_expr, const Base<eT,T2>& vals_expr, const uword in_n_rows, const uword in_n_cols, const bool sort_locations) | |
303 : n_rows(0) | |
304 , n_cols(0) | |
305 , n_elem(0) | |
306 , n_nonzero(0) | |
307 , vec_state(0) | |
308 , values(NULL) | |
309 , row_indices(NULL) | |
310 , col_ptrs(NULL) | |
311 { | |
312 arma_extra_debug_sigprint_this(this); | |
313 | |
314 init(in_n_rows, in_n_cols); | |
315 | |
316 const unwrap<T1> locs_tmp( locations_expr.get_ref() ); | |
317 const Mat<uword>& locs = locs_tmp.M; | |
318 | |
319 const unwrap<T2> vals_tmp( vals_expr.get_ref() ); | |
320 const Mat<eT>& vals = vals_tmp.M; | |
321 | |
322 arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object is not a vector" ); | |
323 | |
324 arma_debug_check((locs.n_rows != 2), "SpMat::SpMat(): locations matrix must have two rows"); | |
325 | |
326 arma_debug_check((locs.n_cols != vals.n_elem), "SpMat::SpMat(): number of locations is different than number of values"); | |
327 | |
328 // Resize to correct number of elements. | |
329 mem_resize(vals.n_elem); | |
330 | |
331 // Reset column pointers to zero. | |
332 arrayops::inplace_set(access::rwp(col_ptrs), uword(0), n_cols + 1); | |
333 | |
334 bool actually_sorted = true; | |
335 if(sort_locations == true) | |
336 { | |
337 // sort_index() uses std::sort() which may use quicksort... so we better | |
338 // make sure it's not already sorted before taking an O(N^2) sort penalty. | |
339 for (uword i = 1; i < locs.n_cols; ++i) | |
340 { | |
341 if ((locs.at(1, i) < locs.at(1, i - 1)) || (locs.at(1, i) == locs.at(1, i - 1) && locs.at(0, i) <= locs.at(0, i - 1))) | |
342 { | |
343 actually_sorted = false; | |
344 break; | |
345 } | |
346 } | |
347 | |
348 if(actually_sorted == false) | |
349 { | |
350 // This may not be the fastest possible implementation but it maximizes code reuse. | |
351 Col<uword> abslocs(locs.n_cols); | |
352 | |
353 for (uword i = 0; i < locs.n_cols; ++i) | |
354 { | |
355 abslocs[i] = locs.at(1, i) * n_rows + locs.at(0, i); | |
356 } | |
357 | |
358 // Now we will sort with sort_index(). | |
359 uvec sorted_indices = sort_index(abslocs); // Ascending sort. | |
360 | |
361 // Now we add the elements in this sorted order. | |
362 for (uword i = 0; i < sorted_indices.n_elem; ++i) | |
363 { | |
364 arma_debug_check((locs.at(0, sorted_indices[i]) >= n_rows), "SpMat::SpMat(): invalid row index"); | |
365 arma_debug_check((locs.at(1, sorted_indices[i]) >= n_cols), "SpMat::SpMat(): invalid column index"); | |
366 | |
367 access::rw(values[i]) = vals[sorted_indices[i]]; | |
368 access::rw(row_indices[i]) = locs.at(0, sorted_indices[i]); | |
369 | |
370 access::rw(col_ptrs[locs.at(1, sorted_indices[i]) + 1])++; | |
371 } | |
372 } | |
373 } | |
374 | |
375 if( (sort_locations == false) || (actually_sorted == true) ) | |
376 { | |
377 // Now set the values and row indices correctly. | |
378 // Increment the column pointers in each column (so they are column "counts"). | |
379 for (uword i = 0; i < vals.n_elem; ++i) | |
380 { | |
381 arma_debug_check((locs.at(0, i) >= n_rows), "SpMat::SpMat(): invalid row index"); | |
382 arma_debug_check((locs.at(1, i) >= n_cols), "SpMat::SpMat(): invalid column index"); | |
383 | |
384 // Check ordering in debug mode. | |
385 if(i > 0) | |
386 { | |
387 arma_debug_check | |
388 ( | |
389 ( (locs.at(1, i) < locs.at(1, i - 1)) || (locs.at(1, i) == locs.at(1, i - 1) && locs.at(0, i) < locs.at(0, i - 1)) ), | |
390 "SpMat::SpMat(): out of order points; either pass sort_locations = true or sort points in column-major ordering" | |
391 ); | |
392 arma_debug_check((locs.at(1, i) == locs.at(1, i - 1) && locs.at(0, i) == locs.at(0, i - 1)), "SpMat::SpMat(): two identical point locations in list"); | |
393 } | |
394 | |
395 access::rw(values[i]) = vals[i]; | |
396 access::rw(row_indices[i]) = locs.at(0, i); | |
397 | |
398 access::rw(col_ptrs[locs.at(1, i) + 1])++; | |
399 } | |
400 } | |
401 | |
402 // Now fix the column pointers. | |
403 for (uword i = 0; i <= n_cols; ++i) | |
404 { | |
405 access::rw(col_ptrs[i + 1]) += col_ptrs[i]; | |
406 } | |
407 } | |
408 | |
409 | |
410 | |
411 /** | |
412 * Simple operators with plain values. These operate on every value in the | |
413 * matrix, so a sparse matrix += 1 will turn all those zeroes into ones. Be | |
414 * careful and make sure that's what you really want! | |
415 */ | |
416 template<typename eT> | |
417 inline | |
418 const SpMat<eT>& | |
419 SpMat<eT>::operator=(const eT val) | |
420 { | |
421 arma_extra_debug_sigprint(); | |
422 | |
423 // Resize to 1x1 then set that to the right value. | |
424 init(1, 1); // Sets col_ptrs to 0. | |
425 mem_resize(1); // One element. | |
426 | |
427 // Manually set element. | |
428 access::rw(values[0]) = val; | |
429 access::rw(row_indices[0]) = 0; | |
430 access::rw(col_ptrs[1]) = 1; | |
431 | |
432 return *this; | |
433 } | |
434 | |
435 | |
436 | |
437 template<typename eT> | |
438 inline | |
439 const SpMat<eT>& | |
440 SpMat<eT>::operator*=(const eT val) | |
441 { | |
442 arma_extra_debug_sigprint(); | |
443 | |
444 if(val == eT(0)) | |
445 { | |
446 // Everything will be zero. | |
447 init(n_rows, n_cols); | |
448 return *this; | |
449 } | |
450 | |
451 arrayops::inplace_mul( access::rwp(values), val, n_nonzero ); | |
452 | |
453 return *this; | |
454 } | |
455 | |
456 | |
457 | |
458 template<typename eT> | |
459 inline | |
460 const SpMat<eT>& | |
461 SpMat<eT>::operator/=(const eT val) | |
462 { | |
463 arma_extra_debug_sigprint(); | |
464 | |
465 arma_debug_check( (val == eT(0)), "element-wise division: division by zero" ); | |
466 | |
467 arrayops::inplace_div( access::rwp(values), val, n_nonzero ); | |
468 | |
469 return *this; | |
470 } | |
471 | |
472 | |
473 | |
474 template<typename eT> | |
475 inline | |
476 const SpMat<eT>& | |
477 SpMat<eT>::operator=(const SpMat<eT>& x) | |
478 { | |
479 arma_extra_debug_sigprint(); | |
480 | |
481 init(x); | |
482 | |
483 return *this; | |
484 } | |
485 | |
486 | |
487 | |
488 template<typename eT> | |
489 inline | |
490 const SpMat<eT>& | |
491 SpMat<eT>::operator+=(const SpMat<eT>& x) | |
492 { | |
493 arma_extra_debug_sigprint(); | |
494 | |
495 arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "addition"); | |
496 | |
497 // Iterate over nonzero values of other matrix. | |
498 for (const_iterator it = x.begin(); it != x.end(); it++) | |
499 { | |
500 get_value(it.row(), it.col()) += *it; | |
501 } | |
502 | |
503 return *this; | |
504 } | |
505 | |
506 | |
507 | |
508 template<typename eT> | |
509 inline | |
510 const SpMat<eT>& | |
511 SpMat<eT>::operator-=(const SpMat<eT>& x) | |
512 { | |
513 arma_extra_debug_sigprint(); | |
514 | |
515 arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "subtraction"); | |
516 | |
517 // Iterate over nonzero values of other matrix. | |
518 for (const_iterator it = x.begin(); it != x.end(); it++) | |
519 { | |
520 get_value(it.row(), it.col()) -= *it; | |
521 } | |
522 | |
523 return *this; | |
524 } | |
525 | |
526 | |
527 | |
528 template<typename eT> | |
529 inline | |
530 const SpMat<eT>& | |
531 SpMat<eT>::operator*=(const SpMat<eT>& y) | |
532 { | |
533 arma_extra_debug_sigprint(); | |
534 | |
535 arma_debug_assert_mul_size(n_rows, n_cols, y.n_rows, y.n_cols, "matrix multiplication"); | |
536 | |
537 SpMat<eT> z; | |
538 z = (*this) * y; | |
539 steal_mem(z); | |
540 | |
541 return *this; | |
542 } | |
543 | |
544 | |
545 | |
546 // This is in-place element-wise matrix multiplication. | |
547 template<typename eT> | |
548 inline | |
549 const SpMat<eT>& | |
550 SpMat<eT>::operator%=(const SpMat<eT>& x) | |
551 { | |
552 arma_extra_debug_sigprint(); | |
553 | |
554 arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "element-wise multiplication"); | |
555 | |
556 // We can do this with two iterators rather simply. | |
557 iterator it = begin(); | |
558 const_iterator x_it = x.begin(); | |
559 | |
560 while (it != end() && x_it != x.end()) | |
561 { | |
562 // One of these will be further advanced than the other (or they will be at the same place). | |
563 if ((it.row() == x_it.row()) && (it.col() == x_it.col())) | |
564 { | |
565 // There is an element at this place in both matrices. Multiply. | |
566 (*it) *= (*x_it); | |
567 | |
568 // Now move on to the next position. | |
569 it++; | |
570 x_it++; | |
571 } | |
572 | |
573 else if ((it.col() < x_it.col()) || ((it.col() == x_it.col()) && (it.row() < x_it.row()))) | |
574 { | |
575 // This case is when our matrix has an element which the other matrix does not. | |
576 // So we must delete this element. | |
577 (*it) = 0; | |
578 | |
579 // Because we have deleted the element, we now have to manually set the position... | |
580 it.internal_pos--; | |
581 | |
582 // Now we can increment our iterator. | |
583 it++; | |
584 } | |
585 | |
586 else /* if our iterator is ahead of the other matrix */ | |
587 { | |
588 // In this case we don't need to set anything to 0; our element is already 0. | |
589 // We can just increment the iterator of the other matrix. | |
590 x_it++; | |
591 } | |
592 | |
593 } | |
594 | |
595 // If we are not at the end of our matrix, then we must terminate the remaining elements. | |
596 while (it != end()) | |
597 { | |
598 (*it) = 0; | |
599 | |
600 // Hack to manually set the position right... | |
601 it.internal_pos--; | |
602 it++; // ...and then an increment. | |
603 } | |
604 | |
605 return *this; | |
606 } | |
607 | |
608 | |
609 | |
610 // Construct a complex matrix out of two non-complex matrices | |
611 template<typename eT> | |
612 template<typename T1, typename T2> | |
613 inline | |
614 SpMat<eT>::SpMat | |
615 ( | |
616 const SpBase<typename SpMat<eT>::pod_type, T1>& A, | |
617 const SpBase<typename SpMat<eT>::pod_type, T2>& B | |
618 ) | |
619 : n_rows(0) | |
620 , n_cols(0) | |
621 , n_elem(0) | |
622 , n_nonzero(0) | |
623 , vec_state(0) | |
624 , values(NULL) // extra element is set when mem_resize is called | |
625 , row_indices(NULL) | |
626 , col_ptrs(NULL) | |
627 { | |
628 arma_extra_debug_sigprint(); | |
629 | |
630 typedef typename T1::elem_type T; | |
631 | |
632 // Make sure eT is complex and T is not (compile-time check). | |
633 arma_type_check(( is_complex<eT>::value == false )); | |
634 arma_type_check(( is_complex< T>::value == true )); | |
635 | |
636 // Compile-time abort if types are not compatible. | |
637 arma_type_check(( is_same_type< std::complex<T>, eT >::value == false )); | |
638 | |
639 const unwrap_spmat<T1> tmp1(A.get_ref()); | |
640 const unwrap_spmat<T2> tmp2(B.get_ref()); | |
641 | |
642 const SpMat<T>& X = tmp1.M; | |
643 const SpMat<T>& Y = tmp2.M; | |
644 | |
645 arma_debug_assert_same_size(X.n_rows, X.n_cols, Y.n_rows, Y.n_cols, "SpMat()"); | |
646 | |
647 const uword l_n_rows = X.n_rows; | |
648 const uword l_n_cols = X.n_cols; | |
649 | |
650 // Set size of matrix correctly. | |
651 init(l_n_rows, l_n_cols); | |
652 mem_resize(n_unique(X, Y, op_n_unique_count())); | |
653 | |
654 // Now on a second iteration, fill it. | |
655 typename SpMat<T>::const_iterator x_it = X.begin(); | |
656 typename SpMat<T>::const_iterator x_end = X.end(); | |
657 | |
658 typename SpMat<T>::const_iterator y_it = Y.begin(); | |
659 typename SpMat<T>::const_iterator y_end = Y.end(); | |
660 | |
661 uword cur_pos = 0; | |
662 | |
663 while ((x_it != x_end) || (y_it != y_end)) | |
664 { | |
665 if(x_it == y_it) // if we are at the same place | |
666 { | |
667 access::rw(values[cur_pos]) = std::complex<T>((T) *x_it, (T) *y_it); | |
668 access::rw(row_indices[cur_pos]) = x_it.row(); | |
669 ++access::rw(col_ptrs[x_it.col() + 1]); | |
670 | |
671 ++x_it; | |
672 ++y_it; | |
673 } | |
674 else | |
675 { | |
676 if((x_it.col() < y_it.col()) || ((x_it.col() == y_it.col()) && (x_it.row() < y_it.row()))) // if y is closer to the end | |
677 { | |
678 access::rw(values[cur_pos]) = std::complex<T>((T) *x_it, T(0)); | |
679 access::rw(row_indices[cur_pos]) = x_it.row(); | |
680 ++access::rw(col_ptrs[x_it.col() + 1]); | |
681 | |
682 ++x_it; | |
683 } | |
684 else // x is closer to the end | |
685 { | |
686 access::rw(values[cur_pos]) = std::complex<T>(T(0), (T) *y_it); | |
687 access::rw(row_indices[cur_pos]) = y_it.row(); | |
688 ++access::rw(col_ptrs[y_it.col() + 1]); | |
689 | |
690 ++y_it; | |
691 } | |
692 } | |
693 | |
694 ++cur_pos; | |
695 } | |
696 | |
697 // Now fix the column pointers; they are supposed to be a sum. | |
698 for (uword c = 1; c <= n_cols; ++c) | |
699 { | |
700 access::rw(col_ptrs[c]) += col_ptrs[c - 1]; | |
701 } | |
702 | |
703 } | |
704 | |
705 | |
706 | |
707 template<typename eT> | |
708 inline | |
709 const SpMat<eT>& | |
710 SpMat<eT>::operator/=(const SpMat<eT>& x) | |
711 { | |
712 arma_extra_debug_sigprint(); | |
713 | |
714 arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "element-wise division"); | |
715 | |
716 // If you use this method, you are probably stupid or misguided, but for compatibility with Mat, we have implemented it anyway. | |
717 // We have to loop over every element, which is not good. In fact, it makes me physically sad to write this. | |
718 for(uword c = 0; c < n_cols; ++c) | |
719 { | |
720 for(uword r = 0; r < n_rows; ++r) | |
721 { | |
722 at(r, c) /= x.at(r, c); | |
723 } | |
724 } | |
725 | |
726 return *this; | |
727 } | |
728 | |
729 | |
730 | |
731 template<typename eT> | |
732 template<typename T1> | |
733 inline | |
734 SpMat<eT>::SpMat(const Base<eT, T1>& x) | |
735 : n_rows(0) | |
736 , n_cols(0) | |
737 , n_elem(0) | |
738 , n_nonzero(0) | |
739 , vec_state(0) | |
740 , values(NULL) // extra element is set when mem_resize is called in operator=() | |
741 , row_indices(NULL) | |
742 , col_ptrs(NULL) | |
743 { | |
744 arma_extra_debug_sigprint_this(this); | |
745 | |
746 (*this).operator=(x); | |
747 } | |
748 | |
749 | |
750 | |
751 template<typename eT> | |
752 template<typename T1> | |
753 inline | |
754 const SpMat<eT>& | |
755 SpMat<eT>::operator=(const Base<eT, T1>& x) | |
756 { | |
757 arma_extra_debug_sigprint(); | |
758 | |
759 const Proxy<T1> p(x.get_ref()); | |
760 | |
761 const uword x_n_rows = p.get_n_rows(); | |
762 const uword x_n_cols = p.get_n_cols(); | |
763 const uword x_n_elem = p.get_n_elem(); | |
764 | |
765 init(x_n_rows, x_n_cols); | |
766 | |
767 // Count number of nonzero elements in base object. | |
768 uword n = 0; | |
769 if(Proxy<T1>::prefer_at_accessor == true) | |
770 { | |
771 for(uword j = 0; j < x_n_cols; ++j) | |
772 for(uword i = 0; i < x_n_rows; ++i) | |
773 { | |
774 if(p.at(i, j) != eT(0)) { ++n; } | |
775 } | |
776 } | |
777 else | |
778 { | |
779 for(uword i = 0; i < x_n_elem; ++i) | |
780 { | |
781 if(p[i] != eT(0)) { ++n; } | |
782 } | |
783 } | |
784 | |
785 mem_resize(n); | |
786 | |
787 // Now the memory is resized correctly; add nonzero elements. | |
788 n = 0; | |
789 for(uword j = 0; j < x_n_cols; ++j) | |
790 for(uword i = 0; i < x_n_rows; ++i) | |
791 { | |
792 const eT val = p.at(i, j); | |
793 | |
794 if(val != eT(0)) | |
795 { | |
796 access::rw(values[n]) = val; | |
797 access::rw(row_indices[n]) = i; | |
798 access::rw(col_ptrs[j + 1])++; | |
799 ++n; | |
800 } | |
801 } | |
802 | |
803 // Sum column counts to be column pointers. | |
804 for(uword c = 1; c <= n_cols; ++c) | |
805 { | |
806 access::rw(col_ptrs[c]) += col_ptrs[c - 1]; | |
807 } | |
808 | |
809 return *this; | |
810 } | |
811 | |
812 | |
813 | |
814 template<typename eT> | |
815 template<typename T1> | |
816 inline | |
817 const SpMat<eT>& | |
818 SpMat<eT>::operator*=(const Base<eT, T1>& y) | |
819 { | |
820 arma_extra_debug_sigprint(); | |
821 | |
822 const Proxy<T1> p(y.get_ref()); | |
823 | |
824 arma_debug_assert_mul_size(n_rows, n_cols, p.get_n_rows(), p.get_n_cols(), "matrix multiplication"); | |
825 | |
826 // We assume the matrix structure is such that we will end up with a sparse | |
827 // matrix. Assuming that every entry in the dense matrix is nonzero (which is | |
828 // a fairly valid assumption), each row with any nonzero elements in it (in this | |
829 // matrix) implies an entire nonzero column. Therefore, we iterate over all | |
830 // the row_indices and count the number of rows with any elements in them | |
831 // (using the quasi-linked-list idea from SYMBMM -- see operator_times.hpp). | |
832 podarray<uword> index(n_rows); | |
833 index.fill(n_rows); // Fill with invalid links. | |
834 | |
835 uword last_index = n_rows + 1; | |
836 for(uword i = 0; i < n_nonzero; ++i) | |
837 { | |
838 if(index[row_indices[i]] == n_rows) | |
839 { | |
840 index[row_indices[i]] = last_index; | |
841 last_index = row_indices[i]; | |
842 } | |
843 } | |
844 | |
845 // Now count the number of rows which have nonzero elements. | |
846 uword nonzero_rows = 0; | |
847 while(last_index != n_rows + 1) | |
848 { | |
849 ++nonzero_rows; | |
850 last_index = index[last_index]; | |
851 } | |
852 | |
853 SpMat<eT> z(n_rows, p.get_n_cols()); | |
854 | |
855 z.mem_resize(nonzero_rows * p.get_n_cols()); // upper bound on size | |
856 | |
857 // Now we have to fill all the elements using a modification of the NUMBMM algorithm. | |
858 uword cur_pos = 0; | |
859 | |
860 podarray<eT> partial_sums(n_rows); | |
861 partial_sums.zeros(); | |
862 | |
863 for(uword lcol = 0; lcol < n_cols; ++lcol) | |
864 { | |
865 const_iterator it = begin(); | |
866 | |
867 while(it != end()) | |
868 { | |
869 const eT value = (*it); | |
870 | |
871 partial_sums[it.row()] += (value * p.at(it.col(), lcol)); | |
872 | |
873 ++it; | |
874 } | |
875 | |
876 // Now add all partial sums to the matrix. | |
877 for(uword i = 0; i < n_rows; ++i) | |
878 { | |
879 if(partial_sums[i] != eT(0)) | |
880 { | |
881 access::rw(z.values[cur_pos]) = partial_sums[i]; | |
882 access::rw(z.row_indices[cur_pos]) = i; | |
883 ++access::rw(z.col_ptrs[lcol + 1]); | |
884 //printf("colptr %d now %d\n", lcol + 1, z.col_ptrs[lcol + 1]); | |
885 ++cur_pos; | |
886 partial_sums[i] = 0; // Would it be faster to do this in batch later? | |
887 } | |
888 } | |
889 } | |
890 | |
891 // Now fix the column pointers. | |
892 for(uword c = 1; c <= z.n_cols; ++c) | |
893 { | |
894 access::rw(z.col_ptrs[c]) += z.col_ptrs[c - 1]; | |
895 } | |
896 | |
897 // Resize to final correct size. | |
898 z.mem_resize(z.col_ptrs[z.n_cols]); | |
899 | |
900 // Now take the memory of the temporary matrix. | |
901 steal_mem(z); | |
902 | |
903 return *this; | |
904 } | |
905 | |
906 | |
907 | |
908 /** | |
909 * Don't use this function. It's not mathematically well-defined and wastes | |
910 * cycles to trash all your data. This is dumb. | |
911 */ | |
912 template<typename eT> | |
913 template<typename T1> | |
914 inline | |
915 const SpMat<eT>& | |
916 SpMat<eT>::operator/=(const Base<eT, T1>& x) | |
917 { | |
918 arma_extra_debug_sigprint(); | |
919 | |
920 SpMat<eT> tmp = (*this) / x.get_ref(); | |
921 | |
922 steal_mem(tmp); | |
923 | |
924 return *this; | |
925 } | |
926 | |
927 | |
928 | |
929 template<typename eT> | |
930 template<typename T1> | |
931 inline | |
932 const SpMat<eT>& | |
933 SpMat<eT>::operator%=(const Base<eT, T1>& x) | |
934 { | |
935 arma_extra_debug_sigprint(); | |
936 | |
937 const Proxy<T1> p(x.get_ref()); | |
938 | |
939 arma_debug_assert_same_size(n_rows, n_cols, p.get_n_rows(), p.get_n_cols(), "element-wise multiplication"); | |
940 | |
941 // Count the number of elements we will need. | |
942 SpMat<eT> tmp(n_rows, n_cols); | |
943 const_iterator it = begin(); | |
944 uword new_n_nonzero = 0; | |
945 | |
946 while(it != end()) | |
947 { | |
948 // prefer_at_accessor == false can't save us any work here | |
949 if(((*it) * p.at(it.row(), it.col())) != eT(0)) | |
950 { | |
951 ++new_n_nonzero; | |
952 } | |
953 ++it; | |
954 } | |
955 | |
956 // Resize. | |
957 tmp.mem_resize(new_n_nonzero); | |
958 | |
959 const_iterator c_it = begin(); | |
960 uword cur_pos = 0; | |
961 while(c_it != end()) | |
962 { | |
963 // prefer_at_accessor == false can't save us any work here | |
964 const eT val = (*c_it) * p.at(c_it.row(), c_it.col()); | |
965 if(val != eT(0)) | |
966 { | |
967 access::rw(tmp.values[cur_pos]) = val; | |
968 access::rw(tmp.row_indices[cur_pos]) = c_it.row(); | |
969 ++access::rw(tmp.col_ptrs[c_it.col() + 1]); | |
970 ++cur_pos; | |
971 } | |
972 | |
973 ++c_it; | |
974 } | |
975 | |
976 // Fix column pointers. | |
977 for(uword c = 1; c <= n_cols; ++c) | |
978 { | |
979 access::rw(tmp.col_ptrs[c]) += tmp.col_ptrs[c - 1]; | |
980 } | |
981 | |
982 steal_mem(tmp); | |
983 | |
984 return *this; | |
985 } | |
986 | |
987 | |
988 | |
989 /** | |
990 * Functions on subviews. | |
991 */ | |
992 template<typename eT> | |
993 inline | |
994 SpMat<eT>::SpMat(const SpSubview<eT>& X) | |
995 : n_rows(0) | |
996 , n_cols(0) | |
997 , n_elem(0) | |
998 , n_nonzero(0) | |
999 , vec_state(0) | |
1000 , values(NULL) // extra element added when mem_resize is called | |
1001 , row_indices(NULL) | |
1002 , col_ptrs(NULL) | |
1003 { | |
1004 arma_extra_debug_sigprint_this(this); | |
1005 | |
1006 (*this).operator=(X); | |
1007 } | |
1008 | |
1009 | |
1010 | |
1011 template<typename eT> | |
1012 inline | |
1013 const SpMat<eT>& | |
1014 SpMat<eT>::operator=(const SpSubview<eT>& X) | |
1015 { | |
1016 arma_extra_debug_sigprint(); | |
1017 | |
1018 const uword in_n_cols = X.n_cols; | |
1019 const uword in_n_rows = X.n_rows; | |
1020 | |
1021 const bool alias = (this == &(X.m)); | |
1022 | |
1023 if(alias == false) | |
1024 { | |
1025 init(in_n_rows, in_n_cols); | |
1026 | |
1027 const uword x_n_nonzero = X.n_nonzero; | |
1028 | |
1029 mem_resize(x_n_nonzero); | |
1030 | |
1031 typename SpSubview<eT>::const_iterator it = X.begin(); | |
1032 | |
1033 while(it != X.end()) | |
1034 { | |
1035 access::rw(row_indices[it.pos()]) = it.row(); | |
1036 access::rw(values[it.pos()]) = (*it); | |
1037 ++access::rw(col_ptrs[it.col() + 1]); | |
1038 ++it; | |
1039 } | |
1040 | |
1041 // Now sum column pointers. | |
1042 for(uword c = 1; c <= n_cols; ++c) | |
1043 { | |
1044 access::rw(col_ptrs[c]) += col_ptrs[c - 1]; | |
1045 } | |
1046 } | |
1047 else | |
1048 { | |
1049 // Create it in a temporary. | |
1050 SpMat<eT> tmp(X); | |
1051 | |
1052 steal_mem(tmp); | |
1053 } | |
1054 | |
1055 return *this; | |
1056 } | |
1057 | |
1058 | |
1059 | |
1060 template<typename eT> | |
1061 inline | |
1062 const SpMat<eT>& | |
1063 SpMat<eT>::operator+=(const SpSubview<eT>& X) | |
1064 { | |
1065 arma_extra_debug_sigprint(); | |
1066 | |
1067 arma_debug_assert_same_size(n_rows, n_cols, X.n_rows, X.n_cols, "addition"); | |
1068 | |
1069 typename SpSubview<eT>::const_iterator it = X.begin(); | |
1070 | |
1071 while(it != X.end()) | |
1072 { | |
1073 at(it.row(), it.col()) += (*it); | |
1074 ++it; | |
1075 } | |
1076 | |
1077 return *this; | |
1078 } | |
1079 | |
1080 | |
1081 | |
1082 template<typename eT> | |
1083 inline | |
1084 const SpMat<eT>& | |
1085 SpMat<eT>::operator-=(const SpSubview<eT>& X) | |
1086 { | |
1087 arma_extra_debug_sigprint(); | |
1088 | |
1089 arma_debug_assert_same_size(n_rows, n_cols, X.n_rows, X.n_cols, "subtraction"); | |
1090 | |
1091 typename SpSubview<eT>::const_iterator it = X.begin(); | |
1092 | |
1093 while(it != X.end()) | |
1094 { | |
1095 at(it.row(), it.col()) -= (*it); | |
1096 ++it; | |
1097 } | |
1098 | |
1099 return *this; | |
1100 } | |
1101 | |
1102 | |
1103 | |
1104 template<typename eT> | |
1105 inline | |
1106 const SpMat<eT>& | |
1107 SpMat<eT>::operator*=(const SpSubview<eT>& y) | |
1108 { | |
1109 arma_extra_debug_sigprint(); | |
1110 | |
1111 arma_debug_assert_mul_size(n_rows, n_cols, y.n_rows, y.n_cols, "matrix multiplication"); | |
1112 | |
1113 // Cannot be done in-place (easily). | |
1114 SpMat<eT> z = (*this) * y; | |
1115 steal_mem(z); | |
1116 | |
1117 return *this; | |
1118 } | |
1119 | |
1120 | |
1121 | |
1122 template<typename eT> | |
1123 inline | |
1124 const SpMat<eT>& | |
1125 SpMat<eT>::operator%=(const SpSubview<eT>& x) | |
1126 { | |
1127 arma_extra_debug_sigprint(); | |
1128 | |
1129 arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "element-wise multiplication"); | |
1130 | |
1131 iterator it = begin(); | |
1132 typename SpSubview<eT>::const_iterator xit = x.begin(); | |
1133 | |
1134 while((it != end()) || (xit != x.end())) | |
1135 { | |
1136 if((xit.row() == it.row()) && (xit.col() == it.col())) | |
1137 { | |
1138 (*it) *= (*xit); | |
1139 ++it; | |
1140 ++xit; | |
1141 } | |
1142 else | |
1143 { | |
1144 if((xit.col() > it.col()) || ((xit.col() == it.col()) && (xit.row() > it.row()))) | |
1145 { | |
1146 // xit is "ahead" | |
1147 (*it) = eT(0); // erase element; x has a zero here | |
1148 it.internal_pos--; // update iterator so it still works | |
1149 ++it; | |
1150 } | |
1151 else | |
1152 { | |
1153 // it is "ahead" | |
1154 ++xit; | |
1155 } | |
1156 } | |
1157 } | |
1158 | |
1159 return *this; | |
1160 } | |
1161 | |
1162 | |
1163 template<typename eT> | |
1164 inline | |
1165 const SpMat<eT>& | |
1166 SpMat<eT>::operator/=(const SpSubview<eT>& x) | |
1167 { | |
1168 arma_extra_debug_sigprint(); | |
1169 | |
1170 arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "element-wise division"); | |
1171 | |
1172 // There is no pretty way to do this. | |
1173 for(uword elem = 0; elem < n_elem; elem++) | |
1174 { | |
1175 at(elem) /= x(elem); | |
1176 } | |
1177 | |
1178 return *this; | |
1179 } | |
1180 | |
1181 | |
1182 | |
1183 /** | |
1184 * Operators on regular subviews. | |
1185 */ | |
1186 template<typename eT> | |
1187 inline | |
1188 SpMat<eT>::SpMat(const subview<eT>& x) | |
1189 : n_rows(0) | |
1190 , n_cols(0) | |
1191 , n_elem(0) | |
1192 , n_nonzero(0) | |
1193 , vec_state(0) | |
1194 , values(NULL) // extra value set in operator=() | |
1195 , row_indices(NULL) | |
1196 , col_ptrs(NULL) | |
1197 { | |
1198 arma_extra_debug_sigprint_this(this); | |
1199 | |
1200 (*this).operator=(x); | |
1201 } | |
1202 | |
1203 | |
1204 | |
1205 template<typename eT> | |
1206 inline | |
1207 const SpMat<eT>& | |
1208 SpMat<eT>::operator=(const subview<eT>& x) | |
1209 { | |
1210 arma_extra_debug_sigprint(); | |
1211 | |
1212 const uword x_n_rows = x.n_rows; | |
1213 const uword x_n_cols = x.n_cols; | |
1214 | |
1215 // Set the size correctly. | |
1216 init(x_n_rows, x_n_cols); | |
1217 | |
1218 // Count number of nonzero elements. | |
1219 uword n = 0; | |
1220 for(uword c = 0; c < x_n_cols; ++c) | |
1221 { | |
1222 for(uword r = 0; r < x_n_rows; ++r) | |
1223 { | |
1224 if(x.at(r, c) != eT(0)) | |
1225 { | |
1226 ++n; | |
1227 } | |
1228 } | |
1229 } | |
1230 | |
1231 // Resize memory appropriately. | |
1232 mem_resize(n); | |
1233 | |
1234 n = 0; | |
1235 for(uword c = 0; c < x_n_cols; ++c) | |
1236 { | |
1237 for(uword r = 0; r < x_n_rows; ++r) | |
1238 { | |
1239 const eT val = x.at(r, c); | |
1240 | |
1241 if(val != eT(0)) | |
1242 { | |
1243 access::rw(values[n]) = val; | |
1244 access::rw(row_indices[n]) = r; | |
1245 ++access::rw(col_ptrs[c + 1]); | |
1246 ++n; | |
1247 } | |
1248 } | |
1249 } | |
1250 | |
1251 // Fix column counts into column pointers. | |
1252 for(uword c = 1; c <= n_cols; ++c) | |
1253 { | |
1254 access::rw(col_ptrs[c]) += col_ptrs[c - 1]; | |
1255 } | |
1256 | |
1257 return *this; | |
1258 } | |
1259 | |
1260 | |
1261 | |
1262 template<typename eT> | |
1263 inline | |
1264 const SpMat<eT>& | |
1265 SpMat<eT>::operator+=(const subview<eT>& x) | |
1266 { | |
1267 arma_extra_debug_sigprint(); | |
1268 | |
1269 arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "addition"); | |
1270 | |
1271 // Loop over every element. This could probably be written in a more | |
1272 // efficient way, by calculating the number of nonzero elements the output | |
1273 // matrix will have, allocating the memory correctly, and then filling the | |
1274 // matrix correctly. However... for now, this works okay. | |
1275 for(uword lcol = 0; lcol < n_cols; ++lcol) | |
1276 for(uword lrow = 0; lrow < n_rows; ++lrow) | |
1277 { | |
1278 at(lrow, lcol) += x.at(lrow, lcol); | |
1279 } | |
1280 | |
1281 return *this; | |
1282 } | |
1283 | |
1284 | |
1285 | |
1286 template<typename eT> | |
1287 inline | |
1288 const SpMat<eT>& | |
1289 SpMat<eT>::operator-=(const subview<eT>& x) | |
1290 { | |
1291 arma_extra_debug_sigprint(); | |
1292 | |
1293 arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "subtraction"); | |
1294 | |
1295 // Loop over every element. | |
1296 for(uword lcol = 0; lcol < n_cols; ++lcol) | |
1297 for(uword lrow = 0; lrow < n_rows; ++lrow) | |
1298 { | |
1299 at(lrow, lcol) -= x.at(lrow, lcol); | |
1300 } | |
1301 | |
1302 return *this; | |
1303 } | |
1304 | |
1305 | |
1306 | |
1307 template<typename eT> | |
1308 inline | |
1309 const SpMat<eT>& | |
1310 SpMat<eT>::operator*=(const subview<eT>& y) | |
1311 { | |
1312 arma_extra_debug_sigprint(); | |
1313 | |
1314 arma_debug_assert_mul_size(n_rows, n_cols, y.n_rows, y.n_cols, "matrix multiplication"); | |
1315 | |
1316 SpMat<eT> z(n_rows, y.n_cols); | |
1317 | |
1318 // Performed in the same fashion as operator*=(SpMat). | |
1319 for (const_row_iterator x_row_it = begin_row(); x_row_it.pos() < n_nonzero; ++x_row_it) | |
1320 { | |
1321 for (uword lcol = 0; lcol < y.n_cols; ++lcol) | |
1322 { | |
1323 // At this moment in the loop, we are calculating anything that is contributed to by *x_row_it and *y_col_it. | |
1324 // Given that our position is x_ab and y_bc, there will only be a contribution if x.col == y.row, and that | |
1325 // contribution will be in location z_ac. | |
1326 z.at(x_row_it.row, lcol) += (*x_row_it) * y.at(x_row_it.col, lcol); | |
1327 } | |
1328 } | |
1329 | |
1330 steal_mem(z); | |
1331 | |
1332 return *this; | |
1333 } | |
1334 | |
1335 | |
1336 | |
1337 template<typename eT> | |
1338 inline | |
1339 const SpMat<eT>& | |
1340 SpMat<eT>::operator%=(const subview<eT>& x) | |
1341 { | |
1342 arma_extra_debug_sigprint(); | |
1343 | |
1344 arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "element-wise multiplication"); | |
1345 | |
1346 // Loop over every element. | |
1347 for(uword lcol = 0; lcol < n_cols; ++lcol) | |
1348 for(uword lrow = 0; lrow < n_rows; ++lrow) | |
1349 { | |
1350 at(lrow, lcol) *= x.at(lrow, lcol); | |
1351 } | |
1352 | |
1353 return *this; | |
1354 } | |
1355 | |
1356 | |
1357 | |
1358 template<typename eT> | |
1359 inline | |
1360 const SpMat<eT>& | |
1361 SpMat<eT>::operator/=(const subview<eT>& x) | |
1362 { | |
1363 arma_extra_debug_sigprint(); | |
1364 | |
1365 arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "element-wise division"); | |
1366 | |
1367 // Loop over every element. | |
1368 for(uword lcol = 0; lcol < n_cols; ++lcol) | |
1369 for(uword lrow = 0; lrow < n_rows; ++lrow) | |
1370 { | |
1371 at(lrow, lcol) /= x.at(lrow, lcol); | |
1372 } | |
1373 | |
1374 return *this; | |
1375 } | |
1376 | |
1377 | |
1378 | |
1379 template<typename eT> | |
1380 template<typename T1, typename spop_type> | |
1381 inline | |
1382 SpMat<eT>::SpMat(const SpOp<T1, spop_type>& X) | |
1383 : n_rows(0) | |
1384 , n_cols(0) | |
1385 , n_elem(0) | |
1386 , n_nonzero(0) | |
1387 , vec_state(0) | |
1388 , values(NULL) // set in application of sparse operation | |
1389 , row_indices(NULL) | |
1390 , col_ptrs(NULL) | |
1391 { | |
1392 arma_extra_debug_sigprint_this(this); | |
1393 | |
1394 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false )); | |
1395 | |
1396 spop_type::apply(*this, X); | |
1397 } | |
1398 | |
1399 | |
1400 | |
1401 template<typename eT> | |
1402 template<typename T1, typename spop_type> | |
1403 inline | |
1404 const SpMat<eT>& | |
1405 SpMat<eT>::operator=(const SpOp<T1, spop_type>& X) | |
1406 { | |
1407 arma_extra_debug_sigprint(); | |
1408 | |
1409 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false )); | |
1410 | |
1411 spop_type::apply(*this, X); | |
1412 | |
1413 return *this; | |
1414 } | |
1415 | |
1416 | |
1417 | |
1418 template<typename eT> | |
1419 template<typename T1, typename spop_type> | |
1420 inline | |
1421 const SpMat<eT>& | |
1422 SpMat<eT>::operator+=(const SpOp<T1, spop_type>& X) | |
1423 { | |
1424 arma_extra_debug_sigprint(); | |
1425 | |
1426 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false )); | |
1427 | |
1428 const SpMat<eT> m(X); | |
1429 | |
1430 return (*this).operator+=(m); | |
1431 } | |
1432 | |
1433 | |
1434 | |
1435 template<typename eT> | |
1436 template<typename T1, typename spop_type> | |
1437 inline | |
1438 const SpMat<eT>& | |
1439 SpMat<eT>::operator-=(const SpOp<T1, spop_type>& X) | |
1440 { | |
1441 arma_extra_debug_sigprint(); | |
1442 | |
1443 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false )); | |
1444 | |
1445 const SpMat<eT> m(X); | |
1446 | |
1447 return (*this).operator-=(m); | |
1448 } | |
1449 | |
1450 | |
1451 | |
1452 template<typename eT> | |
1453 template<typename T1, typename spop_type> | |
1454 inline | |
1455 const SpMat<eT>& | |
1456 SpMat<eT>::operator*=(const SpOp<T1, spop_type>& X) | |
1457 { | |
1458 arma_extra_debug_sigprint(); | |
1459 | |
1460 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false )); | |
1461 | |
1462 const SpMat<eT> m(X); | |
1463 | |
1464 return (*this).operator*=(m); | |
1465 } | |
1466 | |
1467 | |
1468 | |
1469 template<typename eT> | |
1470 template<typename T1, typename spop_type> | |
1471 inline | |
1472 const SpMat<eT>& | |
1473 SpMat<eT>::operator%=(const SpOp<T1, spop_type>& X) | |
1474 { | |
1475 arma_extra_debug_sigprint(); | |
1476 | |
1477 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false )); | |
1478 | |
1479 const SpMat<eT> m(X); | |
1480 | |
1481 return (*this).operator%=(m); | |
1482 } | |
1483 | |
1484 | |
1485 | |
1486 template<typename eT> | |
1487 template<typename T1, typename spop_type> | |
1488 inline | |
1489 const SpMat<eT>& | |
1490 SpMat<eT>::operator/=(const SpOp<T1, spop_type>& X) | |
1491 { | |
1492 arma_extra_debug_sigprint(); | |
1493 | |
1494 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false )); | |
1495 | |
1496 const SpMat<eT> m(X); | |
1497 | |
1498 return (*this).operator/=(m); | |
1499 } | |
1500 | |
1501 | |
1502 | |
1503 template<typename eT> | |
1504 template<typename T1, typename T2, typename spglue_type> | |
1505 inline | |
1506 SpMat<eT>::SpMat(const SpGlue<T1, T2, spglue_type>& X) | |
1507 : n_rows(0) | |
1508 , n_cols(0) | |
1509 , n_elem(0) | |
1510 , n_nonzero(0) | |
1511 , vec_state(0) | |
1512 , values(NULL) // extra element set in application of sparse glue | |
1513 , row_indices(NULL) | |
1514 , col_ptrs(NULL) | |
1515 { | |
1516 arma_extra_debug_sigprint_this(this); | |
1517 | |
1518 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false )); | |
1519 | |
1520 spglue_type::apply(*this, X); | |
1521 } | |
1522 | |
1523 | |
1524 | |
1525 template<typename eT> | |
1526 template<typename T1, typename spop_type> | |
1527 inline | |
1528 SpMat<eT>::SpMat(const mtSpOp<eT, T1, spop_type>& X) | |
1529 : n_rows(0) | |
1530 , n_cols(0) | |
1531 , n_elem(0) | |
1532 , n_nonzero(0) | |
1533 , vec_state(0) | |
1534 , values(NULL) // extra element set in application of sparse glue | |
1535 , row_indices(NULL) | |
1536 , col_ptrs(NULL) | |
1537 { | |
1538 arma_extra_debug_sigprint_this(this); | |
1539 | |
1540 spop_type::apply(*this, X); | |
1541 } | |
1542 | |
1543 | |
1544 | |
1545 template<typename eT> | |
1546 template<typename T1, typename spop_type> | |
1547 inline | |
1548 const SpMat<eT>& | |
1549 SpMat<eT>::operator=(const mtSpOp<eT, T1, spop_type>& X) | |
1550 { | |
1551 arma_extra_debug_sigprint(); | |
1552 | |
1553 spop_type::apply(*this, X); | |
1554 | |
1555 return *this; | |
1556 } | |
1557 | |
1558 | |
1559 | |
1560 template<typename eT> | |
1561 template<typename T1, typename spop_type> | |
1562 inline | |
1563 const SpMat<eT>& | |
1564 SpMat<eT>::operator+=(const mtSpOp<eT, T1, spop_type>& X) | |
1565 { | |
1566 arma_extra_debug_sigprint(); | |
1567 | |
1568 const SpMat<eT> m(X); | |
1569 | |
1570 return (*this).operator+=(m); | |
1571 } | |
1572 | |
1573 | |
1574 | |
1575 template<typename eT> | |
1576 template<typename T1, typename spop_type> | |
1577 inline | |
1578 const SpMat<eT>& | |
1579 SpMat<eT>::operator-=(const mtSpOp<eT, T1, spop_type>& X) | |
1580 { | |
1581 arma_extra_debug_sigprint(); | |
1582 | |
1583 const SpMat<eT> m(X); | |
1584 | |
1585 return (*this).operator-=(m); | |
1586 } | |
1587 | |
1588 | |
1589 | |
1590 template<typename eT> | |
1591 template<typename T1, typename spop_type> | |
1592 inline | |
1593 const SpMat<eT>& | |
1594 SpMat<eT>::operator*=(const mtSpOp<eT, T1, spop_type>& X) | |
1595 { | |
1596 arma_extra_debug_sigprint(); | |
1597 | |
1598 const SpMat<eT> m(X); | |
1599 | |
1600 return (*this).operator*=(m); | |
1601 } | |
1602 | |
1603 | |
1604 | |
1605 template<typename eT> | |
1606 template<typename T1, typename spop_type> | |
1607 inline | |
1608 const SpMat<eT>& | |
1609 SpMat<eT>::operator%=(const mtSpOp<eT, T1, spop_type>& X) | |
1610 { | |
1611 arma_extra_debug_sigprint(); | |
1612 | |
1613 const SpMat<eT> m(X); | |
1614 | |
1615 return (*this).operator%=(m); | |
1616 } | |
1617 | |
1618 | |
1619 | |
1620 template<typename eT> | |
1621 template<typename T1, typename spop_type> | |
1622 inline | |
1623 const SpMat<eT>& | |
1624 SpMat<eT>::operator/=(const mtSpOp<eT, T1, spop_type>& X) | |
1625 { | |
1626 arma_extra_debug_sigprint(); | |
1627 | |
1628 const SpMat<eT> m(X); | |
1629 | |
1630 return (*this).operator/=(m); | |
1631 } | |
1632 | |
1633 | |
1634 | |
1635 template<typename eT> | |
1636 template<typename T1, typename T2, typename spglue_type> | |
1637 inline | |
1638 const SpMat<eT>& | |
1639 SpMat<eT>::operator=(const SpGlue<T1, T2, spglue_type>& X) | |
1640 { | |
1641 arma_extra_debug_sigprint(); | |
1642 | |
1643 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false )); | |
1644 | |
1645 spglue_type::apply(*this, X); | |
1646 | |
1647 return *this; | |
1648 } | |
1649 | |
1650 | |
1651 | |
1652 template<typename eT> | |
1653 template<typename T1, typename T2, typename spglue_type> | |
1654 inline | |
1655 const SpMat<eT>& | |
1656 SpMat<eT>::operator+=(const SpGlue<T1, T2, spglue_type>& X) | |
1657 { | |
1658 arma_extra_debug_sigprint(); | |
1659 | |
1660 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false )); | |
1661 | |
1662 const SpMat<eT> m(X); | |
1663 | |
1664 return (*this).operator+=(m); | |
1665 } | |
1666 | |
1667 | |
1668 | |
1669 template<typename eT> | |
1670 template<typename T1, typename T2, typename spglue_type> | |
1671 inline | |
1672 const SpMat<eT>& | |
1673 SpMat<eT>::operator-=(const SpGlue<T1, T2, spglue_type>& X) | |
1674 { | |
1675 arma_extra_debug_sigprint(); | |
1676 | |
1677 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false )); | |
1678 | |
1679 const SpMat<eT> m(X); | |
1680 | |
1681 return (*this).operator-=(m); | |
1682 } | |
1683 | |
1684 | |
1685 | |
1686 template<typename eT> | |
1687 template<typename T1, typename T2, typename spglue_type> | |
1688 inline | |
1689 const SpMat<eT>& | |
1690 SpMat<eT>::operator*=(const SpGlue<T1, T2, spglue_type>& X) | |
1691 { | |
1692 arma_extra_debug_sigprint(); | |
1693 | |
1694 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false )); | |
1695 | |
1696 const SpMat<eT> m(X); | |
1697 | |
1698 return (*this).operator*=(m); | |
1699 } | |
1700 | |
1701 | |
1702 | |
1703 template<typename eT> | |
1704 template<typename T1, typename T2, typename spglue_type> | |
1705 inline | |
1706 const SpMat<eT>& | |
1707 SpMat<eT>::operator%=(const SpGlue<T1, T2, spglue_type>& X) | |
1708 { | |
1709 arma_extra_debug_sigprint(); | |
1710 | |
1711 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false )); | |
1712 | |
1713 const SpMat<eT> m(X); | |
1714 | |
1715 return (*this).operator%=(m); | |
1716 } | |
1717 | |
1718 | |
1719 | |
1720 template<typename eT> | |
1721 template<typename T1, typename T2, typename spglue_type> | |
1722 inline | |
1723 const SpMat<eT>& | |
1724 SpMat<eT>::operator/=(const SpGlue<T1, T2, spglue_type>& X) | |
1725 { | |
1726 arma_extra_debug_sigprint(); | |
1727 | |
1728 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false )); | |
1729 | |
1730 const SpMat<eT> m(X); | |
1731 | |
1732 return (*this).operator/=(m); | |
1733 } | |
1734 | |
1735 | |
1736 | |
1737 template<typename eT> | |
1738 arma_inline | |
1739 SpSubview<eT> | |
1740 SpMat<eT>::row(const uword row_num) | |
1741 { | |
1742 arma_extra_debug_sigprint(); | |
1743 | |
1744 arma_debug_check(row_num >= n_rows, "SpMat::row(): out of bounds"); | |
1745 | |
1746 return SpSubview<eT>(*this, row_num, 0, 1, n_cols); | |
1747 } | |
1748 | |
1749 | |
1750 | |
1751 template<typename eT> | |
1752 arma_inline | |
1753 const SpSubview<eT> | |
1754 SpMat<eT>::row(const uword row_num) const | |
1755 { | |
1756 arma_extra_debug_sigprint(); | |
1757 | |
1758 arma_debug_check(row_num >= n_rows, "SpMat::row(): out of bounds"); | |
1759 | |
1760 return SpSubview<eT>(*this, row_num, 0, 1, n_cols); | |
1761 } | |
1762 | |
1763 | |
1764 | |
1765 template<typename eT> | |
1766 inline | |
1767 SpSubview<eT> | |
1768 SpMat<eT>::operator()(const uword row_num, const span& col_span) | |
1769 { | |
1770 arma_extra_debug_sigprint(); | |
1771 | |
1772 const bool col_all = col_span.whole; | |
1773 | |
1774 const uword local_n_cols = n_cols; | |
1775 | |
1776 const uword in_col1 = col_all ? 0 : col_span.a; | |
1777 const uword in_col2 = col_span.b; | |
1778 const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; | |
1779 | |
1780 arma_debug_check | |
1781 ( | |
1782 (row_num >= n_rows) | |
1783 || | |
1784 ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) | |
1785 , | |
1786 "SpMat::operator(): indices out of bounds or incorrectly used" | |
1787 ); | |
1788 | |
1789 return SpSubview<eT>(*this, row_num, in_col1, 1, submat_n_cols); | |
1790 } | |
1791 | |
1792 | |
1793 | |
1794 template<typename eT> | |
1795 inline | |
1796 const SpSubview<eT> | |
1797 SpMat<eT>::operator()(const uword row_num, const span& col_span) const | |
1798 { | |
1799 arma_extra_debug_sigprint(); | |
1800 | |
1801 const bool col_all = col_span.whole; | |
1802 | |
1803 const uword local_n_cols = n_cols; | |
1804 | |
1805 const uword in_col1 = col_all ? 0 : col_span.a; | |
1806 const uword in_col2 = col_span.b; | |
1807 const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; | |
1808 | |
1809 arma_debug_check | |
1810 ( | |
1811 (row_num >= n_rows) | |
1812 || | |
1813 ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) | |
1814 , | |
1815 "SpMat::operator(): indices out of bounds or incorrectly used" | |
1816 ); | |
1817 | |
1818 return SpSubview<eT>(*this, row_num, in_col1, 1, submat_n_cols); | |
1819 } | |
1820 | |
1821 | |
1822 | |
1823 template<typename eT> | |
1824 arma_inline | |
1825 SpSubview<eT> | |
1826 SpMat<eT>::col(const uword col_num) | |
1827 { | |
1828 arma_extra_debug_sigprint(); | |
1829 | |
1830 arma_debug_check(col_num >= n_cols, "SpMat::col(): out of bounds"); | |
1831 | |
1832 return SpSubview<eT>(*this, 0, col_num, n_rows, 1); | |
1833 } | |
1834 | |
1835 | |
1836 | |
1837 template<typename eT> | |
1838 arma_inline | |
1839 const SpSubview<eT> | |
1840 SpMat<eT>::col(const uword col_num) const | |
1841 { | |
1842 arma_extra_debug_sigprint(); | |
1843 | |
1844 arma_debug_check(col_num >= n_cols, "SpMat::col(): out of bounds"); | |
1845 | |
1846 return SpSubview<eT>(*this, 0, col_num, n_rows, 1); | |
1847 } | |
1848 | |
1849 | |
1850 | |
1851 template<typename eT> | |
1852 inline | |
1853 SpSubview<eT> | |
1854 SpMat<eT>::operator()(const span& row_span, const uword col_num) | |
1855 { | |
1856 arma_extra_debug_sigprint(); | |
1857 | |
1858 const bool row_all = row_span.whole; | |
1859 | |
1860 const uword local_n_rows = n_rows; | |
1861 | |
1862 const uword in_row1 = row_all ? 0 : row_span.a; | |
1863 const uword in_row2 = row_span.b; | |
1864 const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; | |
1865 | |
1866 arma_debug_check | |
1867 ( | |
1868 (col_num >= n_cols) | |
1869 || | |
1870 ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) | |
1871 , | |
1872 "SpMat::operator(): indices out of bounds or incorrectly used" | |
1873 ); | |
1874 | |
1875 return SpSubview<eT>(*this, in_row1, col_num, submat_n_rows, 1); | |
1876 } | |
1877 | |
1878 | |
1879 | |
1880 template<typename eT> | |
1881 inline | |
1882 const SpSubview<eT> | |
1883 SpMat<eT>::operator()(const span& row_span, const uword col_num) const | |
1884 { | |
1885 arma_extra_debug_sigprint(); | |
1886 | |
1887 const bool row_all = row_span.whole; | |
1888 | |
1889 const uword local_n_rows = n_rows; | |
1890 | |
1891 const uword in_row1 = row_all ? 0 : row_span.a; | |
1892 const uword in_row2 = row_span.b; | |
1893 const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; | |
1894 | |
1895 arma_debug_check | |
1896 ( | |
1897 (col_num >= n_cols) | |
1898 || | |
1899 ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) | |
1900 , | |
1901 "SpMat::operator(): indices out of bounds or incorrectly used" | |
1902 ); | |
1903 | |
1904 return SpSubview<eT>(*this, in_row1, col_num, submat_n_rows, 1); | |
1905 } | |
1906 | |
1907 | |
1908 | |
1909 /** | |
1910 * Swap in_row1 with in_row2. | |
1911 */ | |
1912 template<typename eT> | |
1913 inline | |
1914 void | |
1915 SpMat<eT>::swap_rows(const uword in_row1, const uword in_row2) | |
1916 { | |
1917 arma_extra_debug_sigprint(); | |
1918 | |
1919 arma_debug_check | |
1920 ( | |
1921 (in_row1 >= n_rows) || (in_row2 >= n_rows), | |
1922 "SpMat::swap_rows(): out of bounds" | |
1923 ); | |
1924 | |
1925 // Sanity check. | |
1926 if (in_row1 == in_row2) | |
1927 { | |
1928 return; | |
1929 } | |
1930 | |
1931 // The easier way to do this, instead of collecting all the elements in one row and then swapping with the other, will be | |
1932 // to iterate over each column of the matrix (since we store in column-major format) and then swap the two elements in the two rows at that time. | |
1933 // We will try to avoid using the at() call since it is expensive, instead preferring to use an iterator to track our position. | |
1934 uword col1 = (in_row1 < in_row2) ? in_row1 : in_row2; | |
1935 uword col2 = (in_row1 < in_row2) ? in_row2 : in_row1; | |
1936 | |
1937 for (uword lcol = 0; lcol < n_cols; lcol++) | |
1938 { | |
1939 // If there is nothing in this column we can ignore it. | |
1940 if (col_ptrs[lcol] == col_ptrs[lcol + 1]) | |
1941 { | |
1942 continue; | |
1943 } | |
1944 | |
1945 // These will represent the positions of the items themselves. | |
1946 uword loc1 = n_nonzero + 1; | |
1947 uword loc2 = n_nonzero + 1; | |
1948 | |
1949 for (uword search_pos = col_ptrs[lcol]; search_pos < col_ptrs[lcol + 1]; search_pos++) | |
1950 { | |
1951 if (row_indices[search_pos] == col1) | |
1952 { | |
1953 loc1 = search_pos; | |
1954 } | |
1955 | |
1956 if (row_indices[search_pos] == col2) | |
1957 { | |
1958 loc2 = search_pos; | |
1959 break; // No need to look any further. | |
1960 } | |
1961 } | |
1962 | |
1963 // There are four cases: we found both elements; we found one element (loc1); we found one element (loc2); we found zero elements. | |
1964 // If we found zero elements no work needs to be done and we can continue to the next column. | |
1965 if ((loc1 != (n_nonzero + 1)) && (loc2 != (n_nonzero + 1))) | |
1966 { | |
1967 // This is an easy case: just swap the values. No index modifying necessary. | |
1968 eT tmp = values[loc1]; | |
1969 access::rw(values[loc1]) = values[loc2]; | |
1970 access::rw(values[loc2]) = tmp; | |
1971 } | |
1972 else if (loc1 != (n_nonzero + 1)) // We only found loc1 and not loc2. | |
1973 { | |
1974 // We need to find the correct place to move our value to. It will be forward (not backwards) because in_row2 > in_row1. | |
1975 // Each iteration of the loop swaps the current value (loc1) with (loc1 + 1); in this manner we move our value down to where it should be. | |
1976 while (((loc1 + 1) < col_ptrs[lcol + 1]) && (row_indices[loc1 + 1] < in_row2)) | |
1977 { | |
1978 // Swap both the values and the indices. The column should not change. | |
1979 eT tmp = values[loc1]; | |
1980 access::rw(values[loc1]) = values[loc1 + 1]; | |
1981 access::rw(values[loc1 + 1]) = tmp; | |
1982 | |
1983 uword tmp_index = row_indices[loc1]; | |
1984 access::rw(row_indices[loc1]) = row_indices[loc1 + 1]; | |
1985 access::rw(row_indices[loc1 + 1]) = tmp_index; | |
1986 | |
1987 loc1++; // And increment the counter. | |
1988 } | |
1989 | |
1990 // Now set the row index correctly. | |
1991 access::rw(row_indices[loc1]) = in_row2; | |
1992 | |
1993 } | |
1994 else if (loc2 != (n_nonzero + 1)) | |
1995 { | |
1996 // We need to find the correct place to move our value to. It will be backwards (not forwards) because in_row1 < in_row2. | |
1997 // Each iteration of the loop swaps the current value (loc2) with (loc2 - 1); in this manner we move our value up to where it should be. | |
1998 while (((loc2 - 1) >= col_ptrs[lcol]) && (row_indices[loc2 - 1] > in_row1)) | |
1999 { | |
2000 // Swap both the values and the indices. The column should not change. | |
2001 eT tmp = values[loc2]; | |
2002 access::rw(values[loc2]) = values[loc2 - 1]; | |
2003 access::rw(values[loc2 - 1]) = tmp; | |
2004 | |
2005 uword tmp_index = row_indices[loc2]; | |
2006 access::rw(row_indices[loc2]) = row_indices[loc2 - 1]; | |
2007 access::rw(row_indices[loc2 - 1]) = tmp_index; | |
2008 | |
2009 loc2--; // And decrement the counter. | |
2010 } | |
2011 | |
2012 // Now set the row index correctly. | |
2013 access::rw(row_indices[loc2]) = in_row1; | |
2014 | |
2015 } | |
2016 /* else: no need to swap anything; both values are zero */ | |
2017 } | |
2018 } | |
2019 | |
2020 /** | |
2021 * Swap in_col1 with in_col2. | |
2022 */ | |
2023 template<typename eT> | |
2024 inline | |
2025 void | |
2026 SpMat<eT>::swap_cols(const uword in_col1, const uword in_col2) | |
2027 { | |
2028 arma_extra_debug_sigprint(); | |
2029 | |
2030 // slow but works | |
2031 for(uword lrow = 0; lrow < n_rows; ++lrow) | |
2032 { | |
2033 eT tmp = at(lrow, in_col1); | |
2034 at(lrow, in_col1) = at(lrow, in_col2); | |
2035 at(lrow, in_col2) = tmp; | |
2036 } | |
2037 } | |
2038 | |
2039 /** | |
2040 * Remove the row row_num. | |
2041 */ | |
2042 template<typename eT> | |
2043 inline | |
2044 void | |
2045 SpMat<eT>::shed_row(const uword row_num) | |
2046 { | |
2047 arma_extra_debug_sigprint(); | |
2048 arma_debug_check (row_num >= n_rows, "SpMat::shed_row(): out of bounds"); | |
2049 | |
2050 shed_rows (row_num, row_num); | |
2051 } | |
2052 | |
2053 /** | |
2054 * Remove the column col_num. | |
2055 */ | |
2056 template<typename eT> | |
2057 inline | |
2058 void | |
2059 SpMat<eT>::shed_col(const uword col_num) | |
2060 { | |
2061 arma_extra_debug_sigprint(); | |
2062 arma_debug_check (col_num >= n_cols, "SpMat::shed_col(): out of bounds"); | |
2063 | |
2064 shed_cols(col_num, col_num); | |
2065 } | |
2066 | |
2067 /** | |
2068 * Remove all rows between (and including) in_row1 and in_row2. | |
2069 */ | |
2070 template<typename eT> | |
2071 inline | |
2072 void | |
2073 SpMat<eT>::shed_rows(const uword in_row1, const uword in_row2) | |
2074 { | |
2075 arma_extra_debug_sigprint(); | |
2076 | |
2077 arma_debug_check | |
2078 ( | |
2079 (in_row1 > in_row2) || (in_row2 >= n_rows), | |
2080 "SpMat::shed_rows(): indices out of bounds or incorectly used" | |
2081 ); | |
2082 | |
2083 uword i, j; | |
2084 // Store the length of values | |
2085 uword vlength = n_nonzero; | |
2086 // Store the length of col_ptrs | |
2087 uword clength = n_cols + 1; | |
2088 | |
2089 // This is O(n * n_cols) and inplace, there may be a faster way, though. | |
2090 for (i = 0, j = 0; i < vlength; ++i) | |
2091 { | |
2092 // Store the row of the ith element. | |
2093 const uword lrow = row_indices[i]; | |
2094 // Is the ith element in the range of rows we want to remove? | |
2095 if (lrow >= in_row1 && lrow <= in_row2) | |
2096 { | |
2097 // Increment our "removed elements" counter. | |
2098 ++j; | |
2099 | |
2100 // Adjust the values of col_ptrs each time we remove an element. | |
2101 // Basically, the length of one column reduces by one, and everything to | |
2102 // its right gets reduced by one to represent all the elements being | |
2103 // shifted to the left by one. | |
2104 for(uword k = 0; k < clength; ++k) | |
2105 { | |
2106 if (col_ptrs[k] > (i - j + 1)) | |
2107 { | |
2108 --access::rw(col_ptrs[k]); | |
2109 } | |
2110 } | |
2111 } | |
2112 else | |
2113 { | |
2114 // We shift the element we checked to the left by how many elements | |
2115 // we have removed. | |
2116 // j = 0 until we remove the first element. | |
2117 if (j != 0) | |
2118 { | |
2119 access::rw(row_indices[i - j]) = (lrow > in_row2) ? (lrow - (in_row2 - in_row1 + 1)) : lrow; | |
2120 access::rw(values[i - j]) = values[i]; | |
2121 } | |
2122 } | |
2123 } | |
2124 | |
2125 // j is the number of elements removed. | |
2126 | |
2127 // Shrink the vectors. This will copy the memory. | |
2128 mem_resize(n_nonzero - j); | |
2129 | |
2130 // Adjust row and element counts. | |
2131 access::rw(n_rows) = n_rows - (in_row2 - in_row1) - 1; | |
2132 access::rw(n_elem) = n_rows * n_cols; | |
2133 } | |
2134 | |
2135 /** | |
2136 * Remove all columns between (and including) in_col1 and in_col2. | |
2137 */ | |
2138 template<typename eT> | |
2139 inline | |
2140 void | |
2141 SpMat<eT>::shed_cols(const uword in_col1, const uword in_col2) | |
2142 { | |
2143 arma_extra_debug_sigprint(); | |
2144 | |
2145 arma_debug_check | |
2146 ( | |
2147 (in_col1 > in_col2) || (in_col2 >= n_cols), | |
2148 "SpMat::shed_cols(): indices out of bounds or incorrectly used" | |
2149 ); | |
2150 | |
2151 // First we find the locations in values and row_indices for the column entries. | |
2152 uword col_beg = col_ptrs[in_col1]; | |
2153 uword col_end = col_ptrs[in_col2 + 1]; | |
2154 | |
2155 // Then we find the number of entries in the column. | |
2156 uword diff = col_end - col_beg; | |
2157 | |
2158 if (diff > 0) | |
2159 { | |
2160 eT* new_values = memory::acquire_chunked<eT> (n_nonzero - diff); | |
2161 uword* new_row_indices = memory::acquire_chunked<uword>(n_nonzero - diff); | |
2162 | |
2163 // Copy first part. | |
2164 if (col_beg != 0) | |
2165 { | |
2166 arrayops::copy(new_values, values, col_beg); | |
2167 arrayops::copy(new_row_indices, row_indices, col_beg); | |
2168 } | |
2169 | |
2170 // Copy second part. | |
2171 if (col_end != n_nonzero) | |
2172 { | |
2173 arrayops::copy(new_values + col_beg, values + col_end, n_nonzero - col_end); | |
2174 arrayops::copy(new_row_indices + col_beg, row_indices + col_end, n_nonzero - col_end); | |
2175 } | |
2176 | |
2177 memory::release(values); | |
2178 memory::release(row_indices); | |
2179 | |
2180 access::rw(values) = new_values; | |
2181 access::rw(row_indices) = new_row_indices; | |
2182 | |
2183 // Update counts and such. | |
2184 access::rw(n_nonzero) -= diff; | |
2185 } | |
2186 | |
2187 // Update column pointers. | |
2188 const uword new_n_cols = n_cols - ((in_col2 - in_col1) + 1); | |
2189 | |
2190 uword* new_col_ptrs = memory::acquire<uword>(new_n_cols + 2); | |
2191 new_col_ptrs[new_n_cols + 1] = std::numeric_limits<uword>::max(); | |
2192 | |
2193 // Copy first set of columns (no manipulation required). | |
2194 if (in_col1 != 0) | |
2195 { | |
2196 arrayops::copy(new_col_ptrs, col_ptrs, in_col1); | |
2197 } | |
2198 | |
2199 // Copy second set of columns (manipulation required). | |
2200 uword cur_col = in_col1; | |
2201 for (uword i = in_col2 + 1; i <= n_cols; ++i, ++cur_col) | |
2202 { | |
2203 new_col_ptrs[cur_col] = col_ptrs[i] - diff; | |
2204 } | |
2205 | |
2206 memory::release(col_ptrs); | |
2207 access::rw(col_ptrs) = new_col_ptrs; | |
2208 | |
2209 // We update the element and column counts, and we're done. | |
2210 access::rw(n_cols) = new_n_cols; | |
2211 access::rw(n_elem) = n_cols * n_rows; | |
2212 } | |
2213 | |
2214 | |
2215 | |
2216 template<typename eT> | |
2217 arma_inline | |
2218 SpSubview<eT> | |
2219 SpMat<eT>::rows(const uword in_row1, const uword in_row2) | |
2220 { | |
2221 arma_extra_debug_sigprint(); | |
2222 | |
2223 arma_debug_check | |
2224 ( | |
2225 (in_row1 > in_row2) || (in_row2 >= n_rows), | |
2226 "SpMat::rows(): indices out of bounds or incorrectly used" | |
2227 ); | |
2228 | |
2229 const uword subview_n_rows = in_row2 - in_row1 + 1; | |
2230 | |
2231 return SpSubview<eT>(*this, in_row1, 0, subview_n_rows, n_cols); | |
2232 } | |
2233 | |
2234 | |
2235 | |
2236 template<typename eT> | |
2237 arma_inline | |
2238 const SpSubview<eT> | |
2239 SpMat<eT>::rows(const uword in_row1, const uword in_row2) const | |
2240 { | |
2241 arma_extra_debug_sigprint(); | |
2242 | |
2243 arma_debug_check | |
2244 ( | |
2245 (in_row1 > in_row2) || (in_row2 >= n_rows), | |
2246 "SpMat::rows(): indices out of bounds or incorrectly used" | |
2247 ); | |
2248 | |
2249 const uword subview_n_rows = in_row2 - in_row1 + 1; | |
2250 | |
2251 return SpSubview<eT>(*this, in_row1, 0, subview_n_rows, n_cols); | |
2252 } | |
2253 | |
2254 | |
2255 | |
2256 template<typename eT> | |
2257 arma_inline | |
2258 SpSubview<eT> | |
2259 SpMat<eT>::cols(const uword in_col1, const uword in_col2) | |
2260 { | |
2261 arma_extra_debug_sigprint(); | |
2262 | |
2263 arma_debug_check | |
2264 ( | |
2265 (in_col1 > in_col2) || (in_col2 >= n_cols), | |
2266 "SpMat::cols(): indices out of bounds or incorrectly used" | |
2267 ); | |
2268 | |
2269 const uword subview_n_cols = in_col2 - in_col1 + 1; | |
2270 | |
2271 return SpSubview<eT>(*this, 0, in_col1, n_rows, subview_n_cols); | |
2272 } | |
2273 | |
2274 | |
2275 | |
2276 template<typename eT> | |
2277 arma_inline | |
2278 const SpSubview<eT> | |
2279 SpMat<eT>::cols(const uword in_col1, const uword in_col2) const | |
2280 { | |
2281 arma_extra_debug_sigprint(); | |
2282 | |
2283 arma_debug_check | |
2284 ( | |
2285 (in_col1 > in_col2) || (in_col2 >= n_cols), | |
2286 "SpMat::cols(): indices out of bounds or incorrectly used" | |
2287 ); | |
2288 | |
2289 const uword subview_n_cols = in_col2 - in_col1 + 1; | |
2290 | |
2291 return SpSubview<eT>(*this, 0, in_col1, n_rows, subview_n_cols); | |
2292 } | |
2293 | |
2294 | |
2295 | |
2296 template<typename eT> | |
2297 arma_inline | |
2298 SpSubview<eT> | |
2299 SpMat<eT>::submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) | |
2300 { | |
2301 arma_extra_debug_sigprint(); | |
2302 | |
2303 arma_debug_check | |
2304 ( | |
2305 (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), | |
2306 "SpMat::submat(): indices out of bounds or incorrectly used" | |
2307 ); | |
2308 | |
2309 const uword subview_n_rows = in_row2 - in_row1 + 1; | |
2310 const uword subview_n_cols = in_col2 - in_col1 + 1; | |
2311 | |
2312 return SpSubview<eT>(*this, in_row1, in_col1, subview_n_rows, subview_n_cols); | |
2313 } | |
2314 | |
2315 | |
2316 | |
2317 template<typename eT> | |
2318 arma_inline | |
2319 const SpSubview<eT> | |
2320 SpMat<eT>::submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) const | |
2321 { | |
2322 arma_extra_debug_sigprint(); | |
2323 | |
2324 arma_debug_check | |
2325 ( | |
2326 (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), | |
2327 "SpMat::submat(): indices out of bounds or incorrectly used" | |
2328 ); | |
2329 | |
2330 const uword subview_n_rows = in_row2 - in_row1 + 1; | |
2331 const uword subview_n_cols = in_col2 - in_col1 + 1; | |
2332 | |
2333 return SpSubview<eT>(*this, in_row1, in_col1, subview_n_rows, subview_n_cols); | |
2334 } | |
2335 | |
2336 | |
2337 | |
2338 template<typename eT> | |
2339 inline | |
2340 SpSubview<eT> | |
2341 SpMat<eT>::submat (const span& row_span, const span& col_span) | |
2342 { | |
2343 arma_extra_debug_sigprint(); | |
2344 | |
2345 const bool row_all = row_span.whole; | |
2346 const bool col_all = col_span.whole; | |
2347 | |
2348 const uword local_n_rows = n_rows; | |
2349 const uword local_n_cols = n_cols; | |
2350 | |
2351 const uword in_row1 = row_all ? 0 : row_span.a; | |
2352 const uword in_row2 = row_span.b; | |
2353 const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; | |
2354 | |
2355 const uword in_col1 = col_all ? 0 : col_span.a; | |
2356 const uword in_col2 = col_span.b; | |
2357 const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; | |
2358 | |
2359 arma_debug_check | |
2360 ( | |
2361 ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) | |
2362 || | |
2363 ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) | |
2364 , | |
2365 "SpMat::submat(): indices out of bounds or incorrectly used" | |
2366 ); | |
2367 | |
2368 return SpSubview<eT>(*this, in_row1, in_col1, submat_n_rows, submat_n_cols); | |
2369 } | |
2370 | |
2371 | |
2372 | |
2373 template<typename eT> | |
2374 inline | |
2375 const SpSubview<eT> | |
2376 SpMat<eT>::submat (const span& row_span, const span& col_span) const | |
2377 { | |
2378 arma_extra_debug_sigprint(); | |
2379 | |
2380 const bool row_all = row_span.whole; | |
2381 const bool col_all = col_span.whole; | |
2382 | |
2383 const uword local_n_rows = n_rows; | |
2384 const uword local_n_cols = n_cols; | |
2385 | |
2386 const uword in_row1 = row_all ? 0 : row_span.a; | |
2387 const uword in_row2 = row_span.b; | |
2388 const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; | |
2389 | |
2390 const uword in_col1 = col_all ? 0 : col_span.a; | |
2391 const uword in_col2 = col_span.b; | |
2392 const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; | |
2393 | |
2394 arma_debug_check | |
2395 ( | |
2396 ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) | |
2397 || | |
2398 ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) | |
2399 , | |
2400 "SpMat::submat(): indices out of bounds or incorrectly used" | |
2401 ); | |
2402 | |
2403 return SpSubview<eT>(*this, in_row1, in_col1, submat_n_rows, submat_n_cols); | |
2404 } | |
2405 | |
2406 | |
2407 | |
2408 template<typename eT> | |
2409 inline | |
2410 SpSubview<eT> | |
2411 SpMat<eT>::operator()(const span& row_span, const span& col_span) | |
2412 { | |
2413 arma_extra_debug_sigprint(); | |
2414 | |
2415 return submat(row_span, col_span); | |
2416 } | |
2417 | |
2418 | |
2419 | |
2420 template<typename eT> | |
2421 inline | |
2422 const SpSubview<eT> | |
2423 SpMat<eT>::operator()(const span& row_span, const span& col_span) const | |
2424 { | |
2425 arma_extra_debug_sigprint(); | |
2426 | |
2427 return submat(row_span, col_span); | |
2428 } | |
2429 | |
2430 | |
2431 | |
2432 /** | |
2433 * Element access; acces the i'th element (works identically to the Mat accessors). | |
2434 * If there is nothing at element i, 0 is returned. | |
2435 * | |
2436 * @param i Element to access. | |
2437 */ | |
2438 | |
2439 template<typename eT> | |
2440 arma_inline | |
2441 arma_warn_unused | |
2442 SpValProxy<SpMat<eT> > | |
2443 SpMat<eT>::operator[](const uword i) | |
2444 { | |
2445 return get_value(i); | |
2446 } | |
2447 | |
2448 | |
2449 | |
2450 template<typename eT> | |
2451 arma_inline | |
2452 arma_warn_unused | |
2453 eT | |
2454 SpMat<eT>::operator[](const uword i) const | |
2455 { | |
2456 return get_value(i); | |
2457 } | |
2458 | |
2459 | |
2460 | |
2461 template<typename eT> | |
2462 arma_inline | |
2463 arma_warn_unused | |
2464 SpValProxy<SpMat<eT> > | |
2465 SpMat<eT>::at(const uword i) | |
2466 { | |
2467 return get_value(i); | |
2468 } | |
2469 | |
2470 | |
2471 | |
2472 template<typename eT> | |
2473 arma_inline | |
2474 arma_warn_unused | |
2475 eT | |
2476 SpMat<eT>::at(const uword i) const | |
2477 { | |
2478 return get_value(i); | |
2479 } | |
2480 | |
2481 | |
2482 | |
2483 template<typename eT> | |
2484 arma_inline | |
2485 arma_warn_unused | |
2486 SpValProxy<SpMat<eT> > | |
2487 SpMat<eT>::operator()(const uword i) | |
2488 { | |
2489 arma_debug_check( (i >= n_elem), "SpMat::operator(): out of bounds"); | |
2490 return get_value(i); | |
2491 } | |
2492 | |
2493 | |
2494 | |
2495 template<typename eT> | |
2496 arma_inline | |
2497 arma_warn_unused | |
2498 eT | |
2499 SpMat<eT>::operator()(const uword i) const | |
2500 { | |
2501 arma_debug_check( (i >= n_elem), "SpMat::operator(): out of bounds"); | |
2502 return get_value(i); | |
2503 } | |
2504 | |
2505 | |
2506 | |
2507 /** | |
2508 * Element access; access the element at row in_rows and column in_col. | |
2509 * If there is nothing at that position, 0 is returned. | |
2510 */ | |
2511 | |
2512 template<typename eT> | |
2513 arma_inline | |
2514 arma_warn_unused | |
2515 SpValProxy<SpMat<eT> > | |
2516 SpMat<eT>::at(const uword in_row, const uword in_col) | |
2517 { | |
2518 return get_value(in_row, in_col); | |
2519 } | |
2520 | |
2521 | |
2522 | |
2523 template<typename eT> | |
2524 arma_inline | |
2525 arma_warn_unused | |
2526 eT | |
2527 SpMat<eT>::at(const uword in_row, const uword in_col) const | |
2528 { | |
2529 return get_value(in_row, in_col); | |
2530 } | |
2531 | |
2532 | |
2533 | |
2534 template<typename eT> | |
2535 arma_inline | |
2536 arma_warn_unused | |
2537 SpValProxy<SpMat<eT> > | |
2538 SpMat<eT>::operator()(const uword in_row, const uword in_col) | |
2539 { | |
2540 arma_debug_check( ((in_row >= n_rows) || (in_col >= n_cols)), "SpMat::operator(): out of bounds"); | |
2541 return get_value(in_row, in_col); | |
2542 } | |
2543 | |
2544 | |
2545 | |
2546 template<typename eT> | |
2547 arma_inline | |
2548 arma_warn_unused | |
2549 eT | |
2550 SpMat<eT>::operator()(const uword in_row, const uword in_col) const | |
2551 { | |
2552 arma_debug_check( ((in_row >= n_rows) || (in_col >= n_cols)), "SpMat::operator(): out of bounds"); | |
2553 return get_value(in_row, in_col); | |
2554 } | |
2555 | |
2556 | |
2557 | |
2558 /** | |
2559 * Check if matrix is empty (no size, no values). | |
2560 */ | |
2561 template<typename eT> | |
2562 arma_inline | |
2563 arma_warn_unused | |
2564 bool | |
2565 SpMat<eT>::is_empty() const | |
2566 { | |
2567 return(n_elem == 0); | |
2568 } | |
2569 | |
2570 | |
2571 | |
2572 //! returns true if the object can be interpreted as a column or row vector | |
2573 template<typename eT> | |
2574 arma_inline | |
2575 arma_warn_unused | |
2576 bool | |
2577 SpMat<eT>::is_vec() const | |
2578 { | |
2579 return ( (n_rows == 1) || (n_cols == 1) ); | |
2580 } | |
2581 | |
2582 | |
2583 | |
2584 //! returns true if the object can be interpreted as a row vector | |
2585 template<typename eT> | |
2586 arma_inline | |
2587 arma_warn_unused | |
2588 bool | |
2589 SpMat<eT>::is_rowvec() const | |
2590 { | |
2591 return (n_rows == 1); | |
2592 } | |
2593 | |
2594 | |
2595 | |
2596 //! returns true if the object can be interpreted as a column vector | |
2597 template<typename eT> | |
2598 arma_inline | |
2599 arma_warn_unused | |
2600 bool | |
2601 SpMat<eT>::is_colvec() const | |
2602 { | |
2603 return (n_cols == 1); | |
2604 } | |
2605 | |
2606 | |
2607 | |
2608 //! returns true if the object has the same number of non-zero rows and columnns | |
2609 template<typename eT> | |
2610 arma_inline | |
2611 arma_warn_unused | |
2612 bool | |
2613 SpMat<eT>::is_square() const | |
2614 { | |
2615 return (n_rows == n_cols); | |
2616 } | |
2617 | |
2618 | |
2619 | |
2620 //! returns true if all of the elements are finite | |
2621 template<typename eT> | |
2622 inline | |
2623 arma_warn_unused | |
2624 bool | |
2625 SpMat<eT>::is_finite() const | |
2626 { | |
2627 for(uword i = 0; i < n_nonzero; i++) | |
2628 { | |
2629 if(arma_isfinite(values[i]) == false) | |
2630 { | |
2631 return false; | |
2632 } | |
2633 } | |
2634 | |
2635 return true; // No infinite values. | |
2636 } | |
2637 | |
2638 | |
2639 | |
2640 //! returns true if the given index is currently in range | |
2641 template<typename eT> | |
2642 arma_inline | |
2643 arma_warn_unused | |
2644 bool | |
2645 SpMat<eT>::in_range(const uword i) const | |
2646 { | |
2647 return (i < n_elem); | |
2648 } | |
2649 | |
2650 | |
2651 //! returns true if the given start and end indices are currently in range | |
2652 template<typename eT> | |
2653 arma_inline | |
2654 arma_warn_unused | |
2655 bool | |
2656 SpMat<eT>::in_range(const span& x) const | |
2657 { | |
2658 arma_extra_debug_sigprint(); | |
2659 | |
2660 if(x.whole == true) | |
2661 { | |
2662 return true; | |
2663 } | |
2664 else | |
2665 { | |
2666 const uword a = x.a; | |
2667 const uword b = x.b; | |
2668 | |
2669 return ( (a <= b) && (b < n_elem) ); | |
2670 } | |
2671 } | |
2672 | |
2673 | |
2674 | |
2675 //! returns true if the given location is currently in range | |
2676 template<typename eT> | |
2677 arma_inline | |
2678 arma_warn_unused | |
2679 bool | |
2680 SpMat<eT>::in_range(const uword in_row, const uword in_col) const | |
2681 { | |
2682 return ( (in_row < n_rows) && (in_col < n_cols) ); | |
2683 } | |
2684 | |
2685 | |
2686 | |
2687 template<typename eT> | |
2688 arma_inline | |
2689 arma_warn_unused | |
2690 bool | |
2691 SpMat<eT>::in_range(const span& row_span, const uword in_col) const | |
2692 { | |
2693 arma_extra_debug_sigprint(); | |
2694 | |
2695 if(row_span.whole == true) | |
2696 { | |
2697 return (in_col < n_cols); | |
2698 } | |
2699 else | |
2700 { | |
2701 const uword in_row1 = row_span.a; | |
2702 const uword in_row2 = row_span.b; | |
2703 | |
2704 return ( (in_row1 <= in_row2) && (in_row2 < n_rows) && (in_col < n_cols) ); | |
2705 } | |
2706 } | |
2707 | |
2708 | |
2709 | |
2710 template<typename eT> | |
2711 arma_inline | |
2712 arma_warn_unused | |
2713 bool | |
2714 SpMat<eT>::in_range(const uword in_row, const span& col_span) const | |
2715 { | |
2716 arma_extra_debug_sigprint(); | |
2717 | |
2718 if(col_span.whole == true) | |
2719 { | |
2720 return (in_row < n_rows); | |
2721 } | |
2722 else | |
2723 { | |
2724 const uword in_col1 = col_span.a; | |
2725 const uword in_col2 = col_span.b; | |
2726 | |
2727 return ( (in_row < n_rows) && (in_col1 <= in_col2) && (in_col2 < n_cols) ); | |
2728 } | |
2729 } | |
2730 | |
2731 | |
2732 | |
2733 template<typename eT> | |
2734 arma_inline | |
2735 arma_warn_unused | |
2736 bool | |
2737 SpMat<eT>::in_range(const span& row_span, const span& col_span) const | |
2738 { | |
2739 arma_extra_debug_sigprint(); | |
2740 | |
2741 const uword in_row1 = row_span.a; | |
2742 const uword in_row2 = row_span.b; | |
2743 | |
2744 const uword in_col1 = col_span.a; | |
2745 const uword in_col2 = col_span.b; | |
2746 | |
2747 const bool rows_ok = row_span.whole ? true : ( (in_row1 <= in_row2) && (in_row2 < n_rows) ); | |
2748 const bool cols_ok = col_span.whole ? true : ( (in_col1 <= in_col2) && (in_col2 < n_cols) ); | |
2749 | |
2750 return ( (rows_ok == true) && (cols_ok == true) ); | |
2751 } | |
2752 | |
2753 | |
2754 | |
2755 template<typename eT> | |
2756 inline | |
2757 void | |
2758 SpMat<eT>::impl_print(const std::string& extra_text) const | |
2759 { | |
2760 arma_extra_debug_sigprint(); | |
2761 | |
2762 if(extra_text.length() != 0) | |
2763 { | |
2764 const std::streamsize orig_width = ARMA_DEFAULT_OSTREAM.width(); | |
2765 | |
2766 ARMA_DEFAULT_OSTREAM << extra_text << '\n'; | |
2767 | |
2768 ARMA_DEFAULT_OSTREAM.width(orig_width); | |
2769 } | |
2770 | |
2771 arma_ostream::print(ARMA_DEFAULT_OSTREAM, *this, true); | |
2772 } | |
2773 | |
2774 | |
2775 | |
2776 template<typename eT> | |
2777 inline | |
2778 void | |
2779 SpMat<eT>::impl_print(std::ostream& user_stream, const std::string& extra_text) const | |
2780 { | |
2781 arma_extra_debug_sigprint(); | |
2782 | |
2783 if(extra_text.length() != 0) | |
2784 { | |
2785 const std::streamsize orig_width = user_stream.width(); | |
2786 | |
2787 user_stream << extra_text << '\n'; | |
2788 | |
2789 user_stream.width(orig_width); | |
2790 } | |
2791 | |
2792 arma_ostream::print(user_stream, *this, true); | |
2793 } | |
2794 | |
2795 | |
2796 | |
2797 template<typename eT> | |
2798 inline | |
2799 void | |
2800 SpMat<eT>::impl_raw_print(const std::string& extra_text) const | |
2801 { | |
2802 arma_extra_debug_sigprint(); | |
2803 | |
2804 if(extra_text.length() != 0) | |
2805 { | |
2806 const std::streamsize orig_width = ARMA_DEFAULT_OSTREAM.width(); | |
2807 | |
2808 ARMA_DEFAULT_OSTREAM << extra_text << '\n'; | |
2809 | |
2810 ARMA_DEFAULT_OSTREAM.width(orig_width); | |
2811 } | |
2812 | |
2813 arma_ostream::print(ARMA_DEFAULT_OSTREAM, *this, false); | |
2814 } | |
2815 | |
2816 | |
2817 template<typename eT> | |
2818 inline | |
2819 void | |
2820 SpMat<eT>::impl_raw_print(std::ostream& user_stream, const std::string& extra_text) const | |
2821 { | |
2822 arma_extra_debug_sigprint(); | |
2823 | |
2824 if(extra_text.length() != 0) | |
2825 { | |
2826 const std::streamsize orig_width = user_stream.width(); | |
2827 | |
2828 user_stream << extra_text << '\n'; | |
2829 | |
2830 user_stream.width(orig_width); | |
2831 } | |
2832 | |
2833 arma_ostream::print(user_stream, *this, false); | |
2834 } | |
2835 | |
2836 | |
2837 | |
2838 /** | |
2839 * Matrix printing, prepends supplied text. | |
2840 * Prints 0 wherever no element exists. | |
2841 */ | |
2842 template<typename eT> | |
2843 inline | |
2844 void | |
2845 SpMat<eT>::impl_print_dense(const std::string& extra_text) const | |
2846 { | |
2847 arma_extra_debug_sigprint(); | |
2848 | |
2849 if(extra_text.length() != 0) | |
2850 { | |
2851 const std::streamsize orig_width = ARMA_DEFAULT_OSTREAM.width(); | |
2852 | |
2853 ARMA_DEFAULT_OSTREAM << extra_text << '\n'; | |
2854 | |
2855 ARMA_DEFAULT_OSTREAM.width(orig_width); | |
2856 } | |
2857 | |
2858 arma_ostream::print_dense(ARMA_DEFAULT_OSTREAM, *this, true); | |
2859 } | |
2860 | |
2861 | |
2862 | |
2863 template<typename eT> | |
2864 inline | |
2865 void | |
2866 SpMat<eT>::impl_print_dense(std::ostream& user_stream, const std::string& extra_text) const | |
2867 { | |
2868 arma_extra_debug_sigprint(); | |
2869 | |
2870 if(extra_text.length() != 0) | |
2871 { | |
2872 const std::streamsize orig_width = user_stream.width(); | |
2873 | |
2874 user_stream << extra_text << '\n'; | |
2875 | |
2876 user_stream.width(orig_width); | |
2877 } | |
2878 | |
2879 arma_ostream::print_dense(user_stream, *this, true); | |
2880 } | |
2881 | |
2882 | |
2883 | |
2884 template<typename eT> | |
2885 inline | |
2886 void | |
2887 SpMat<eT>::impl_raw_print_dense(const std::string& extra_text) const | |
2888 { | |
2889 arma_extra_debug_sigprint(); | |
2890 | |
2891 if(extra_text.length() != 0) | |
2892 { | |
2893 const std::streamsize orig_width = ARMA_DEFAULT_OSTREAM.width(); | |
2894 | |
2895 ARMA_DEFAULT_OSTREAM << extra_text << '\n'; | |
2896 | |
2897 ARMA_DEFAULT_OSTREAM.width(orig_width); | |
2898 } | |
2899 | |
2900 arma_ostream::print_dense(ARMA_DEFAULT_OSTREAM, *this, false); | |
2901 } | |
2902 | |
2903 | |
2904 | |
2905 template<typename eT> | |
2906 inline | |
2907 void | |
2908 SpMat<eT>::impl_raw_print_dense(std::ostream& user_stream, const std::string& extra_text) const | |
2909 { | |
2910 arma_extra_debug_sigprint(); | |
2911 | |
2912 if(extra_text.length() != 0) | |
2913 { | |
2914 const std::streamsize orig_width = user_stream.width(); | |
2915 | |
2916 user_stream << extra_text << '\n'; | |
2917 | |
2918 user_stream.width(orig_width); | |
2919 } | |
2920 | |
2921 arma_ostream::print_dense(user_stream, *this, false); | |
2922 } | |
2923 | |
2924 | |
2925 | |
2926 //! Set the size to the size of another matrix. | |
2927 template<typename eT> | |
2928 template<typename eT2> | |
2929 inline | |
2930 void | |
2931 SpMat<eT>::copy_size(const SpMat<eT2>& m) | |
2932 { | |
2933 arma_extra_debug_sigprint(); | |
2934 | |
2935 init(m.n_rows, m.n_cols); | |
2936 } | |
2937 | |
2938 | |
2939 | |
2940 template<typename eT> | |
2941 template<typename eT2> | |
2942 inline | |
2943 void | |
2944 SpMat<eT>::copy_size(const Mat<eT2>& m) | |
2945 { | |
2946 arma_extra_debug_sigprint(); | |
2947 | |
2948 init(m.n_rows, m.n_cols); | |
2949 } | |
2950 | |
2951 | |
2952 | |
2953 /** | |
2954 * Resize the matrix to a given size. The matrix will be resized to be a column vector (i.e. in_elem columns, 1 row). | |
2955 * | |
2956 * @param in_elem Number of elements to allow. | |
2957 */ | |
2958 template<typename eT> | |
2959 inline | |
2960 void | |
2961 SpMat<eT>::set_size(const uword in_elem) | |
2962 { | |
2963 arma_extra_debug_sigprint(); | |
2964 | |
2965 // If this is a row vector, we resize to a row vector. | |
2966 if(vec_state == 2) | |
2967 { | |
2968 init(1, in_elem); | |
2969 } | |
2970 else | |
2971 { | |
2972 init(in_elem, 1); | |
2973 } | |
2974 } | |
2975 | |
2976 | |
2977 | |
2978 /** | |
2979 * Resize the matrix to a given size. | |
2980 * | |
2981 * @param in_rows Number of rows to allow. | |
2982 * @param in_cols Number of columns to allow. | |
2983 */ | |
2984 template<typename eT> | |
2985 inline | |
2986 void | |
2987 SpMat<eT>::set_size(const uword in_rows, const uword in_cols) | |
2988 { | |
2989 arma_extra_debug_sigprint(); | |
2990 | |
2991 init(in_rows, in_cols); | |
2992 } | |
2993 | |
2994 | |
2995 | |
2996 template<typename eT> | |
2997 inline | |
2998 void | |
2999 SpMat<eT>::reshape(const uword in_rows, const uword in_cols, const uword dim) | |
3000 { | |
3001 arma_extra_debug_sigprint(); | |
3002 | |
3003 if (dim == 0) | |
3004 { | |
3005 // We have to modify all of the relevant row indices and the relevant column pointers. | |
3006 // Iterate over all the points to do this. We won't be deleting any points, but we will be modifying | |
3007 // columns and rows. We'll have to store a new set of column vectors. | |
3008 uword* new_col_ptrs = memory::acquire<uword>(in_cols + 2); | |
3009 new_col_ptrs[in_cols + 1] = std::numeric_limits<uword>::max(); | |
3010 | |
3011 uword* new_row_indices = memory::acquire_chunked<uword>(n_nonzero + 1); | |
3012 access::rw(new_row_indices[n_nonzero]) = 0; | |
3013 | |
3014 arrayops::inplace_set(new_col_ptrs, uword(0), in_cols + 1); | |
3015 | |
3016 for(const_iterator it = begin(); it != end(); it++) | |
3017 { | |
3018 uword vector_position = (it.col() * n_rows) + it.row(); | |
3019 new_row_indices[it.pos()] = vector_position % in_rows; | |
3020 ++new_col_ptrs[vector_position / in_rows + 1]; | |
3021 } | |
3022 | |
3023 // Now sum the column counts to get the new column pointers. | |
3024 for(uword i = 1; i <= in_cols; i++) | |
3025 { | |
3026 access::rw(new_col_ptrs[i]) += new_col_ptrs[i - 1]; | |
3027 } | |
3028 | |
3029 // Copy the new row indices. | |
3030 memory::release(row_indices); | |
3031 access::rw(row_indices) = new_row_indices; | |
3032 | |
3033 memory::release(col_ptrs); | |
3034 access::rw(col_ptrs) = new_col_ptrs; | |
3035 | |
3036 // Now set the size. | |
3037 access::rw(n_rows) = in_rows; | |
3038 access::rw(n_cols) = in_cols; | |
3039 } | |
3040 else | |
3041 { | |
3042 // Row-wise reshaping. This is more tedious and we will use a separate sparse matrix to do it. | |
3043 SpMat<eT> tmp(in_rows, in_cols); | |
3044 | |
3045 for(const_row_iterator it = begin_row(); it.pos() < n_nonzero; it++) | |
3046 { | |
3047 uword vector_position = (it.row() * n_cols) + it.col(); | |
3048 | |
3049 tmp((vector_position / in_cols), (vector_position % in_cols)) = (*it); | |
3050 } | |
3051 | |
3052 (*this).operator=(tmp); | |
3053 } | |
3054 } | |
3055 | |
3056 | |
3057 | |
3058 template<typename eT> | |
3059 inline | |
3060 const SpMat<eT>& | |
3061 SpMat<eT>::zeros() | |
3062 { | |
3063 arma_extra_debug_sigprint(); | |
3064 | |
3065 if (n_nonzero > 0) | |
3066 { | |
3067 memory::release(values); | |
3068 memory::release(row_indices); | |
3069 | |
3070 access::rw(values) = memory::acquire_chunked<eT>(1); | |
3071 access::rw(row_indices) = memory::acquire_chunked<uword>(1); | |
3072 | |
3073 access::rw(values[0]) = 0; | |
3074 access::rw(row_indices[0]) = 0; | |
3075 } | |
3076 | |
3077 access::rw(n_nonzero) = 0; | |
3078 arrayops::inplace_set(access::rwp(col_ptrs), uword(0), n_cols + 1); | |
3079 | |
3080 return *this; | |
3081 } | |
3082 | |
3083 | |
3084 | |
3085 template<typename eT> | |
3086 inline | |
3087 const SpMat<eT>& | |
3088 SpMat<eT>::zeros(const uword in_elem) | |
3089 { | |
3090 arma_extra_debug_sigprint(); | |
3091 | |
3092 if(vec_state == 2) | |
3093 { | |
3094 init(1, in_elem); // Row vector | |
3095 } | |
3096 else | |
3097 { | |
3098 init(in_elem, 1); | |
3099 } | |
3100 | |
3101 return *this; | |
3102 } | |
3103 | |
3104 | |
3105 | |
3106 template<typename eT> | |
3107 inline | |
3108 const SpMat<eT>& | |
3109 SpMat<eT>::zeros(const uword in_rows, const uword in_cols) | |
3110 { | |
3111 arma_extra_debug_sigprint(); | |
3112 | |
3113 init(in_rows, in_cols); | |
3114 | |
3115 return *this; | |
3116 } | |
3117 | |
3118 | |
3119 | |
3120 template<typename eT> | |
3121 inline | |
3122 const SpMat<eT>& | |
3123 SpMat<eT>::eye() | |
3124 { | |
3125 arma_extra_debug_sigprint(); | |
3126 | |
3127 return (*this).eye(n_rows, n_cols); | |
3128 } | |
3129 | |
3130 | |
3131 | |
3132 template<typename eT> | |
3133 inline | |
3134 const SpMat<eT>& | |
3135 SpMat<eT>::eye(const uword in_rows, const uword in_cols) | |
3136 { | |
3137 arma_extra_debug_sigprint(); | |
3138 | |
3139 const uword N = (std::min)(in_rows, in_cols); | |
3140 | |
3141 init(in_rows, in_cols); | |
3142 | |
3143 mem_resize(N); | |
3144 | |
3145 arrayops::inplace_set(access::rwp(values), eT(1), N); | |
3146 | |
3147 for(uword i = 0; i < N; ++i) { access::rw(row_indices[i]) = i; } | |
3148 | |
3149 for(uword i = 0; i <= N; ++i) { access::rw(col_ptrs[i]) = i; } | |
3150 | |
3151 access::rw(n_nonzero) = N; | |
3152 | |
3153 return *this; | |
3154 } | |
3155 | |
3156 | |
3157 | |
3158 template<typename eT> | |
3159 inline | |
3160 const SpMat<eT>& | |
3161 SpMat<eT>::speye() | |
3162 { | |
3163 arma_extra_debug_sigprint(); | |
3164 | |
3165 return (*this).eye(n_rows, n_cols); | |
3166 } | |
3167 | |
3168 | |
3169 | |
3170 template<typename eT> | |
3171 inline | |
3172 const SpMat<eT>& | |
3173 SpMat<eT>::speye(const uword in_n_rows, const uword in_n_cols) | |
3174 { | |
3175 arma_extra_debug_sigprint(); | |
3176 | |
3177 return (*this).eye(in_n_rows, in_n_cols); | |
3178 } | |
3179 | |
3180 | |
3181 | |
3182 template<typename eT> | |
3183 inline | |
3184 const SpMat<eT>& | |
3185 SpMat<eT>::sprandu(const uword in_rows, const uword in_cols, const double density) | |
3186 { | |
3187 arma_extra_debug_sigprint(); | |
3188 | |
3189 arma_debug_check( ( (density < double(0)) || (density > double(1)) ), "sprandu(): density must be in the [0,1] interval" ); | |
3190 | |
3191 zeros(in_rows, in_cols); | |
3192 | |
3193 mem_resize( uword(density * double(in_rows) * double(in_cols) + 0.5) ); | |
3194 | |
3195 if(n_nonzero == 0) | |
3196 { | |
3197 return *this; | |
3198 } | |
3199 | |
3200 eop_aux_randu<eT>::fill( access::rwp(values), n_nonzero ); | |
3201 | |
3202 uvec indices = linspace<uvec>( 0u, in_rows*in_cols-1, n_nonzero ); | |
3203 | |
3204 // perturb the indices | |
3205 for(uword i=1; i < n_nonzero-1; ++i) | |
3206 { | |
3207 const uword index_left = indices[i-1]; | |
3208 const uword index_right = indices[i+1]; | |
3209 | |
3210 const uword center = (index_left + index_right) / 2; | |
3211 | |
3212 const uword delta1 = center - index_left - 1; | |
3213 const uword delta2 = index_right - center - 1; | |
3214 | |
3215 const uword min_delta = (std::min)(delta1, delta2); | |
3216 | |
3217 uword index_new = uword( double(center) + double(min_delta) * (2.0*randu()-1.0) ); | |
3218 | |
3219 // paranoia, but better be safe than sorry | |
3220 if( (index_left < index_new) && (index_new < index_right) ) | |
3221 { | |
3222 indices[i] = index_new; | |
3223 } | |
3224 } | |
3225 | |
3226 uword cur_index = 0; | |
3227 uword count = 0; | |
3228 | |
3229 for(uword lcol = 0; lcol < in_cols; ++lcol) | |
3230 for(uword lrow = 0; lrow < in_rows; ++lrow) | |
3231 { | |
3232 if(count == indices[cur_index]) | |
3233 { | |
3234 access::rw(row_indices[cur_index]) = lrow; | |
3235 access::rw(col_ptrs[lcol + 1])++; | |
3236 ++cur_index; | |
3237 } | |
3238 | |
3239 ++count; | |
3240 } | |
3241 | |
3242 if(cur_index != n_nonzero) | |
3243 { | |
3244 // Fix size to correct size. | |
3245 mem_resize(cur_index); | |
3246 } | |
3247 | |
3248 // Sum column pointers. | |
3249 for(uword lcol = 1; lcol <= in_cols; ++lcol) | |
3250 { | |
3251 access::rw(col_ptrs[lcol]) += col_ptrs[lcol - 1]; | |
3252 } | |
3253 | |
3254 return *this; | |
3255 } | |
3256 | |
3257 | |
3258 | |
3259 template<typename eT> | |
3260 inline | |
3261 const SpMat<eT>& | |
3262 SpMat<eT>::sprandn(const uword in_rows, const uword in_cols, const double density) | |
3263 { | |
3264 arma_extra_debug_sigprint(); | |
3265 | |
3266 arma_debug_check( ( (density < double(0)) || (density > double(1)) ), "sprandn(): density must be in the [0,1] interval" ); | |
3267 | |
3268 zeros(in_rows, in_cols); | |
3269 | |
3270 mem_resize( uword(density * double(in_rows) * double(in_cols) + 0.5) ); | |
3271 | |
3272 if(n_nonzero == 0) | |
3273 { | |
3274 return *this; | |
3275 } | |
3276 | |
3277 eop_aux_randn<eT>::fill( access::rwp(values), n_nonzero ); | |
3278 | |
3279 uvec indices = linspace<uvec>( 0u, in_rows*in_cols-1, n_nonzero ); | |
3280 | |
3281 // perturb the indices | |
3282 for(uword i=1; i < n_nonzero-1; ++i) | |
3283 { | |
3284 const uword index_left = indices[i-1]; | |
3285 const uword index_right = indices[i+1]; | |
3286 | |
3287 const uword center = (index_left + index_right) / 2; | |
3288 | |
3289 const uword delta1 = center - index_left - 1; | |
3290 const uword delta2 = index_right - center - 1; | |
3291 | |
3292 const uword min_delta = (std::min)(delta1, delta2); | |
3293 | |
3294 uword index_new = uword( double(center) + double(min_delta) * (2.0*randu()-1.0) ); | |
3295 | |
3296 // paranoia, but better be safe than sorry | |
3297 if( (index_left < index_new) && (index_new < index_right) ) | |
3298 { | |
3299 indices[i] = index_new; | |
3300 } | |
3301 } | |
3302 | |
3303 uword cur_index = 0; | |
3304 uword count = 0; | |
3305 | |
3306 for(uword lcol = 0; lcol < in_cols; ++lcol) | |
3307 for(uword lrow = 0; lrow < in_rows; ++lrow) | |
3308 { | |
3309 if(count == indices[cur_index]) | |
3310 { | |
3311 access::rw(row_indices[cur_index]) = lrow; | |
3312 access::rw(col_ptrs[lcol + 1])++; | |
3313 ++cur_index; | |
3314 } | |
3315 | |
3316 ++count; | |
3317 } | |
3318 | |
3319 if(cur_index != n_nonzero) | |
3320 { | |
3321 // Fix size to correct size. | |
3322 mem_resize(cur_index); | |
3323 } | |
3324 | |
3325 // Sum column pointers. | |
3326 for(uword lcol = 1; lcol <= in_cols; ++lcol) | |
3327 { | |
3328 access::rw(col_ptrs[lcol]) += col_ptrs[lcol - 1]; | |
3329 } | |
3330 | |
3331 return *this; | |
3332 } | |
3333 | |
3334 | |
3335 | |
3336 template<typename eT> | |
3337 inline | |
3338 void | |
3339 SpMat<eT>::reset() | |
3340 { | |
3341 arma_extra_debug_sigprint(); | |
3342 | |
3343 set_size(0, 0); | |
3344 } | |
3345 | |
3346 | |
3347 | |
3348 /** | |
3349 * Get the minimum or the maximum of the matrix. | |
3350 */ | |
3351 template<typename eT> | |
3352 inline | |
3353 arma_warn_unused | |
3354 eT | |
3355 SpMat<eT>::min() const | |
3356 { | |
3357 arma_extra_debug_sigprint(); | |
3358 | |
3359 arma_debug_check((n_elem == 0), "min(): object has no elements"); | |
3360 | |
3361 if (n_nonzero == 0) | |
3362 { | |
3363 return 0; | |
3364 } | |
3365 | |
3366 eT val = op_min::direct_min(values, n_nonzero); | |
3367 | |
3368 if ((val > 0) && (n_nonzero < n_elem)) // A sparse 0 is less. | |
3369 { | |
3370 val = 0; | |
3371 } | |
3372 | |
3373 return val; | |
3374 } | |
3375 | |
3376 | |
3377 | |
3378 template<typename eT> | |
3379 inline | |
3380 eT | |
3381 SpMat<eT>::min(uword& index_of_min_val) const | |
3382 { | |
3383 arma_extra_debug_sigprint(); | |
3384 | |
3385 arma_debug_check((n_elem == 0), "min(): object has no elements"); | |
3386 | |
3387 eT val = 0; | |
3388 | |
3389 if (n_nonzero == 0) // There are no other elements. It must be 0. | |
3390 { | |
3391 index_of_min_val = 0; | |
3392 } | |
3393 else | |
3394 { | |
3395 uword location; | |
3396 val = op_min::direct_min(values, n_nonzero, location); | |
3397 | |
3398 if ((val > 0) && (n_nonzero < n_elem)) // A sparse 0 is less. | |
3399 { | |
3400 val = 0; | |
3401 | |
3402 // Give back the index to the first zero position. | |
3403 index_of_min_val = 0; | |
3404 while (get_position(index_of_min_val) == index_of_min_val) // An element exists at that position. | |
3405 { | |
3406 index_of_min_val++; | |
3407 } | |
3408 | |
3409 } | |
3410 else | |
3411 { | |
3412 index_of_min_val = get_position(location); | |
3413 } | |
3414 } | |
3415 | |
3416 return val; | |
3417 | |
3418 } | |
3419 | |
3420 | |
3421 | |
3422 template<typename eT> | |
3423 inline | |
3424 eT | |
3425 SpMat<eT>::min(uword& row_of_min_val, uword& col_of_min_val) const | |
3426 { | |
3427 arma_extra_debug_sigprint(); | |
3428 | |
3429 arma_debug_check((n_elem == 0), "min(): object has no elements"); | |
3430 | |
3431 eT val = 0; | |
3432 | |
3433 if (n_nonzero == 0) // There are no other elements. It must be 0. | |
3434 { | |
3435 row_of_min_val = 0; | |
3436 col_of_min_val = 0; | |
3437 } | |
3438 else | |
3439 { | |
3440 uword location; | |
3441 val = op_min::direct_min(values, n_nonzero, location); | |
3442 | |
3443 if ((val > 0) && (n_nonzero < n_elem)) // A sparse 0 is less. | |
3444 { | |
3445 val = 0; | |
3446 | |
3447 location = 0; | |
3448 while (get_position(location) == location) // An element exists at that position. | |
3449 { | |
3450 location++; | |
3451 } | |
3452 | |
3453 row_of_min_val = location % n_rows; | |
3454 col_of_min_val = location / n_rows; | |
3455 } | |
3456 else | |
3457 { | |
3458 get_position(location, row_of_min_val, col_of_min_val); | |
3459 } | |
3460 } | |
3461 | |
3462 return val; | |
3463 | |
3464 } | |
3465 | |
3466 | |
3467 | |
3468 template<typename eT> | |
3469 inline | |
3470 arma_warn_unused | |
3471 eT | |
3472 SpMat<eT>::max() const | |
3473 { | |
3474 arma_extra_debug_sigprint(); | |
3475 | |
3476 arma_debug_check((n_elem == 0), "max(): object has no elements"); | |
3477 | |
3478 if (n_nonzero == 0) | |
3479 { | |
3480 return 0; | |
3481 } | |
3482 | |
3483 eT val = op_max::direct_max(values, n_nonzero); | |
3484 | |
3485 if ((val < 0) && (n_nonzero < n_elem)) // A sparse 0 is more. | |
3486 { | |
3487 return 0; | |
3488 } | |
3489 | |
3490 return val; | |
3491 | |
3492 } | |
3493 | |
3494 | |
3495 | |
3496 template<typename eT> | |
3497 inline | |
3498 eT | |
3499 SpMat<eT>::max(uword& index_of_max_val) const | |
3500 { | |
3501 arma_extra_debug_sigprint(); | |
3502 | |
3503 arma_debug_check((n_elem == 0), "max(): object has no elements"); | |
3504 | |
3505 eT val = 0; | |
3506 | |
3507 if (n_nonzero == 0) | |
3508 { | |
3509 index_of_max_val = 0; | |
3510 } | |
3511 else | |
3512 { | |
3513 uword location; | |
3514 val = op_max::direct_max(values, n_nonzero, location); | |
3515 | |
3516 if ((val < 0) && (n_nonzero < n_elem)) // A sparse 0 is more. | |
3517 { | |
3518 val = 0; | |
3519 | |
3520 location = 0; | |
3521 while (get_position(location) == location) // An element exists at that position. | |
3522 { | |
3523 location++; | |
3524 } | |
3525 | |
3526 } | |
3527 else | |
3528 { | |
3529 index_of_max_val = get_position(location); | |
3530 } | |
3531 | |
3532 } | |
3533 | |
3534 return val; | |
3535 | |
3536 } | |
3537 | |
3538 | |
3539 | |
3540 template<typename eT> | |
3541 inline | |
3542 eT | |
3543 SpMat<eT>::max(uword& row_of_max_val, uword& col_of_max_val) const | |
3544 { | |
3545 arma_extra_debug_sigprint(); | |
3546 | |
3547 arma_debug_check((n_elem == 0), "max(): object has no elements"); | |
3548 | |
3549 eT val = 0; | |
3550 | |
3551 if (n_nonzero == 0) | |
3552 { | |
3553 row_of_max_val = 0; | |
3554 col_of_max_val = 0; | |
3555 } | |
3556 else | |
3557 { | |
3558 uword location; | |
3559 val = op_max::direct_max(values, n_nonzero, location); | |
3560 | |
3561 if ((val < 0) && (n_nonzero < n_elem)) // A sparse 0 is more. | |
3562 { | |
3563 val = 0; | |
3564 | |
3565 location = 0; | |
3566 while (get_position(location) == location) // An element exists at that position. | |
3567 { | |
3568 location++; | |
3569 } | |
3570 | |
3571 row_of_max_val = location % n_rows; | |
3572 col_of_max_val = location / n_rows; | |
3573 | |
3574 } | |
3575 else | |
3576 { | |
3577 get_position(location, row_of_max_val, col_of_max_val); | |
3578 } | |
3579 | |
3580 } | |
3581 | |
3582 return val; | |
3583 | |
3584 } | |
3585 | |
3586 | |
3587 | |
3588 //! save the matrix to a file | |
3589 template<typename eT> | |
3590 inline | |
3591 bool | |
3592 SpMat<eT>::save(const std::string name, const file_type type, const bool print_status) const | |
3593 { | |
3594 arma_extra_debug_sigprint(); | |
3595 | |
3596 bool save_okay; | |
3597 | |
3598 switch(type) | |
3599 { | |
3600 // case raw_ascii: | |
3601 // save_okay = diskio::save_raw_ascii(*this, name); | |
3602 // break; | |
3603 | |
3604 // case csv_ascii: | |
3605 // save_okay = diskio::save_csv_ascii(*this, name); | |
3606 // break; | |
3607 | |
3608 case arma_binary: | |
3609 save_okay = diskio::save_arma_binary(*this, name); | |
3610 break; | |
3611 | |
3612 case coord_ascii: | |
3613 save_okay = diskio::save_coord_ascii(*this, name); | |
3614 break; | |
3615 | |
3616 default: | |
3617 arma_warn(true, "SpMat::save(): unsupported file type"); | |
3618 save_okay = false; | |
3619 } | |
3620 | |
3621 arma_warn( (save_okay == false), "SpMat::save(): couldn't write to ", name); | |
3622 | |
3623 return save_okay; | |
3624 } | |
3625 | |
3626 | |
3627 | |
3628 //! save the matrix to a stream | |
3629 template<typename eT> | |
3630 inline | |
3631 bool | |
3632 SpMat<eT>::save(std::ostream& os, const file_type type, const bool print_status) const | |
3633 { | |
3634 arma_extra_debug_sigprint(); | |
3635 | |
3636 bool save_okay; | |
3637 | |
3638 switch(type) | |
3639 { | |
3640 // case raw_ascii: | |
3641 // save_okay = diskio::save_raw_ascii(*this, os); | |
3642 // break; | |
3643 | |
3644 // case csv_ascii: | |
3645 // save_okay = diskio::save_csv_ascii(*this, os); | |
3646 // break; | |
3647 | |
3648 case arma_binary: | |
3649 save_okay = diskio::save_arma_binary(*this, os); | |
3650 break; | |
3651 | |
3652 case coord_ascii: | |
3653 save_okay = diskio::save_coord_ascii(*this, os); | |
3654 break; | |
3655 | |
3656 default: | |
3657 arma_warn(true, "SpMat::save(): unsupported file type"); | |
3658 save_okay = false; | |
3659 } | |
3660 | |
3661 arma_warn( (save_okay == false), "SpMat::save(): couldn't write to the given stream"); | |
3662 | |
3663 return save_okay; | |
3664 } | |
3665 | |
3666 | |
3667 | |
3668 //! load a matrix from a file | |
3669 template<typename eT> | |
3670 inline | |
3671 bool | |
3672 SpMat<eT>::load(const std::string name, const file_type type, const bool print_status) | |
3673 { | |
3674 arma_extra_debug_sigprint(); | |
3675 | |
3676 bool load_okay; | |
3677 std::string err_msg; | |
3678 | |
3679 switch(type) | |
3680 { | |
3681 // case auto_detect: | |
3682 // load_okay = diskio::load_auto_detect(*this, name, err_msg); | |
3683 // break; | |
3684 | |
3685 // case raw_ascii: | |
3686 // load_okay = diskio::load_raw_ascii(*this, name, err_msg); | |
3687 // break; | |
3688 | |
3689 // case csv_ascii: | |
3690 // load_okay = diskio::load_csv_ascii(*this, name, err_msg); | |
3691 // break; | |
3692 | |
3693 case arma_binary: | |
3694 load_okay = diskio::load_arma_binary(*this, name, err_msg); | |
3695 break; | |
3696 | |
3697 case coord_ascii: | |
3698 load_okay = diskio::load_coord_ascii(*this, name, err_msg); | |
3699 break; | |
3700 | |
3701 default: | |
3702 arma_warn(true, "SpMat::load(): unsupported file type"); | |
3703 load_okay = false; | |
3704 } | |
3705 | |
3706 if(load_okay == false) | |
3707 { | |
3708 if(err_msg.length() > 0) | |
3709 { | |
3710 arma_warn(true, "SpMat::load(): ", err_msg, name); | |
3711 } | |
3712 else | |
3713 { | |
3714 arma_warn(true, "SpMat::load(): couldn't read ", name); | |
3715 } | |
3716 } | |
3717 | |
3718 if(load_okay == false) | |
3719 { | |
3720 (*this).reset(); | |
3721 } | |
3722 | |
3723 return load_okay; | |
3724 } | |
3725 | |
3726 | |
3727 | |
3728 //! load a matrix from a stream | |
3729 template<typename eT> | |
3730 inline | |
3731 bool | |
3732 SpMat<eT>::load(std::istream& is, const file_type type, const bool print_status) | |
3733 { | |
3734 arma_extra_debug_sigprint(); | |
3735 | |
3736 bool load_okay; | |
3737 std::string err_msg; | |
3738 | |
3739 switch(type) | |
3740 { | |
3741 // case auto_detect: | |
3742 // load_okay = diskio::load_auto_detect(*this, is, err_msg); | |
3743 // break; | |
3744 | |
3745 // case raw_ascii: | |
3746 // load_okay = diskio::load_raw_ascii(*this, is, err_msg); | |
3747 // break; | |
3748 | |
3749 // case csv_ascii: | |
3750 // load_okay = diskio::load_csv_ascii(*this, is, err_msg); | |
3751 // break; | |
3752 | |
3753 case arma_binary: | |
3754 load_okay = diskio::load_arma_binary(*this, is, err_msg); | |
3755 break; | |
3756 | |
3757 case coord_ascii: | |
3758 load_okay = diskio::load_coord_ascii(*this, is, err_msg); | |
3759 break; | |
3760 | |
3761 default: | |
3762 arma_warn(true, "SpMat::load(): unsupported file type"); | |
3763 load_okay = false; | |
3764 } | |
3765 | |
3766 | |
3767 if(load_okay == false) | |
3768 { | |
3769 if(err_msg.length() > 0) | |
3770 { | |
3771 arma_warn(true, "SpMat::load(): ", err_msg, "the given stream"); | |
3772 } | |
3773 else | |
3774 { | |
3775 arma_warn(true, "SpMat::load(): couldn't load from the given stream"); | |
3776 } | |
3777 } | |
3778 | |
3779 if(load_okay == false) | |
3780 { | |
3781 (*this).reset(); | |
3782 } | |
3783 | |
3784 return load_okay; | |
3785 } | |
3786 | |
3787 | |
3788 | |
3789 //! save the matrix to a file, without printing any error messages | |
3790 template<typename eT> | |
3791 inline | |
3792 bool | |
3793 SpMat<eT>::quiet_save(const std::string name, const file_type type) const | |
3794 { | |
3795 arma_extra_debug_sigprint(); | |
3796 | |
3797 return (*this).save(name, type, false); | |
3798 } | |
3799 | |
3800 | |
3801 | |
3802 //! save the matrix to a stream, without printing any error messages | |
3803 template<typename eT> | |
3804 inline | |
3805 bool | |
3806 SpMat<eT>::quiet_save(std::ostream& os, const file_type type) const | |
3807 { | |
3808 arma_extra_debug_sigprint(); | |
3809 | |
3810 return (*this).save(os, type, false); | |
3811 } | |
3812 | |
3813 | |
3814 | |
3815 //! load a matrix from a file, without printing any error messages | |
3816 template<typename eT> | |
3817 inline | |
3818 bool | |
3819 SpMat<eT>::quiet_load(const std::string name, const file_type type) | |
3820 { | |
3821 arma_extra_debug_sigprint(); | |
3822 | |
3823 return (*this).load(name, type, false); | |
3824 } | |
3825 | |
3826 | |
3827 | |
3828 //! load a matrix from a stream, without printing any error messages | |
3829 template<typename eT> | |
3830 inline | |
3831 bool | |
3832 SpMat<eT>::quiet_load(std::istream& is, const file_type type) | |
3833 { | |
3834 arma_extra_debug_sigprint(); | |
3835 | |
3836 return (*this).load(is, type, false); | |
3837 } | |
3838 | |
3839 | |
3840 | |
3841 /** | |
3842 * Initialize the matrix to the specified size. Data is not preserved, so the matrix is assumed to be entirely sparse (empty). | |
3843 */ | |
3844 template<typename eT> | |
3845 inline | |
3846 void | |
3847 SpMat<eT>::init(uword in_rows, uword in_cols) | |
3848 { | |
3849 arma_extra_debug_sigprint(); | |
3850 | |
3851 // Verify that we are allowed to do this. | |
3852 if(vec_state > 0) | |
3853 { | |
3854 if((in_rows == 0) && (in_cols == 0)) | |
3855 { | |
3856 if(vec_state == 1) | |
3857 { | |
3858 in_cols = 1; | |
3859 } | |
3860 else | |
3861 if(vec_state == 2) | |
3862 { | |
3863 in_rows = 1; | |
3864 } | |
3865 } | |
3866 else | |
3867 { | |
3868 arma_debug_check | |
3869 ( | |
3870 ( ((vec_state == 1) && (in_cols != 1)) || ((vec_state == 2) && (in_rows != 1)) ), | |
3871 "SpMat::init(): object is a row or column vector; requested size is not compatible" | |
3872 ); | |
3873 } | |
3874 } | |
3875 | |
3876 // Ensure that n_elem can hold the result of (n_rows * n_cols) | |
3877 arma_debug_check | |
3878 ( | |
3879 ( | |
3880 ( (in_rows > ARMA_MAX_UHWORD) || (in_cols > ARMA_MAX_UHWORD) ) | |
3881 ? ( (float(in_rows) * float(in_cols)) > float(ARMA_MAX_UWORD) ) | |
3882 : false | |
3883 ), | |
3884 "SpMat::init(): requested size is too large" | |
3885 ); | |
3886 | |
3887 // Clean out the existing memory. | |
3888 if (values) | |
3889 { | |
3890 memory::release(values); | |
3891 memory::release(row_indices); | |
3892 } | |
3893 | |
3894 access::rw(values) = memory::acquire_chunked<eT> (1); | |
3895 access::rw(row_indices) = memory::acquire_chunked<uword>(1); | |
3896 | |
3897 access::rw(values[0]) = 0; | |
3898 access::rw(row_indices[0]) = 0; | |
3899 | |
3900 memory::release(col_ptrs); | |
3901 | |
3902 // Set the new size accordingly. | |
3903 access::rw(n_rows) = in_rows; | |
3904 access::rw(n_cols) = in_cols; | |
3905 access::rw(n_elem) = (in_rows * in_cols); | |
3906 access::rw(n_nonzero) = 0; | |
3907 | |
3908 // Try to allocate the column pointers, filling them with 0, except for the | |
3909 // last element which contains the maximum possible element (so iterators | |
3910 // terminate correctly). | |
3911 access::rw(col_ptrs) = memory::acquire<uword>(in_cols + 2); | |
3912 access::rw(col_ptrs[in_cols + 1]) = std::numeric_limits<uword>::max(); | |
3913 | |
3914 arrayops::inplace_set(access::rwp(col_ptrs), uword(0), in_cols + 1); | |
3915 } | |
3916 | |
3917 | |
3918 | |
3919 /** | |
3920 * Initialize the matrix from a string. | |
3921 */ | |
3922 template<typename eT> | |
3923 inline | |
3924 void | |
3925 SpMat<eT>::init(const std::string& text) | |
3926 { | |
3927 arma_extra_debug_sigprint(); | |
3928 | |
3929 // Figure out the size first. | |
3930 uword t_n_rows = 0; | |
3931 uword t_n_cols = 0; | |
3932 | |
3933 bool t_n_cols_found = false; | |
3934 | |
3935 std::string token; | |
3936 | |
3937 std::string::size_type line_start = 0; | |
3938 std::string::size_type line_end = 0; | |
3939 | |
3940 while (line_start < text.length()) | |
3941 { | |
3942 | |
3943 line_end = text.find(';', line_start); | |
3944 | |
3945 if (line_end == std::string::npos) | |
3946 line_end = text.length() - 1; | |
3947 | |
3948 std::string::size_type line_len = line_end - line_start + 1; | |
3949 std::stringstream line_stream(text.substr(line_start, line_len)); | |
3950 | |
3951 // Step through each column. | |
3952 uword line_n_cols = 0; | |
3953 | |
3954 while (line_stream >> token) | |
3955 { | |
3956 ++line_n_cols; | |
3957 } | |
3958 | |
3959 if (line_n_cols > 0) | |
3960 { | |
3961 if (t_n_cols_found == false) | |
3962 { | |
3963 t_n_cols = line_n_cols; | |
3964 t_n_cols_found = true; | |
3965 } | |
3966 else // Check it each time through, just to make sure. | |
3967 arma_check((line_n_cols != t_n_cols), "SpMat::init(): inconsistent number of columns in given string"); | |
3968 | |
3969 ++t_n_rows; | |
3970 } | |
3971 | |
3972 line_start = line_end + 1; | |
3973 | |
3974 } | |
3975 | |
3976 set_size(t_n_rows, t_n_cols); | |
3977 | |
3978 // Second time through will pick up all the values. | |
3979 line_start = 0; | |
3980 line_end = 0; | |
3981 | |
3982 uword lrow = 0; | |
3983 | |
3984 while (line_start < text.length()) | |
3985 { | |
3986 | |
3987 line_end = text.find(';', line_start); | |
3988 | |
3989 if (line_end == std::string::npos) | |
3990 line_end = text.length() - 1; | |
3991 | |
3992 std::string::size_type line_len = line_end - line_start + 1; | |
3993 std::stringstream line_stream(text.substr(line_start, line_len)); | |
3994 | |
3995 uword lcol = 0; | |
3996 eT val; | |
3997 | |
3998 while (line_stream >> val) | |
3999 { | |
4000 // Only add nonzero elements. | |
4001 if (val != eT(0)) | |
4002 { | |
4003 get_value(lrow, lcol) = val; | |
4004 } | |
4005 | |
4006 ++lcol; | |
4007 } | |
4008 | |
4009 ++lrow; | |
4010 line_start = line_end + 1; | |
4011 | |
4012 } | |
4013 | |
4014 } | |
4015 | |
4016 /** | |
4017 * Copy from another matrix. | |
4018 */ | |
4019 template<typename eT> | |
4020 inline | |
4021 void | |
4022 SpMat<eT>::init(const SpMat<eT>& x) | |
4023 { | |
4024 arma_extra_debug_sigprint(); | |
4025 | |
4026 // Ensure we are not initializing to ourselves. | |
4027 if (this != &x) | |
4028 { | |
4029 init(x.n_rows, x.n_cols); | |
4030 | |
4031 // values and row_indices may not be null. | |
4032 if (values != NULL) | |
4033 { | |
4034 memory::release(values); | |
4035 memory::release(row_indices); | |
4036 } | |
4037 | |
4038 access::rw(values) = memory::acquire_chunked<eT> (x.n_nonzero + 1); | |
4039 access::rw(row_indices) = memory::acquire_chunked<uword>(x.n_nonzero + 1); | |
4040 | |
4041 // Now copy over the elements. | |
4042 arrayops::copy(access::rwp(values), x.values, x.n_nonzero + 1); | |
4043 arrayops::copy(access::rwp(row_indices), x.row_indices, x.n_nonzero + 1); | |
4044 arrayops::copy(access::rwp(col_ptrs), x.col_ptrs, x.n_cols + 1); | |
4045 | |
4046 access::rw(n_nonzero) = x.n_nonzero; | |
4047 } | |
4048 } | |
4049 | |
4050 | |
4051 | |
4052 template<typename eT> | |
4053 inline | |
4054 void | |
4055 SpMat<eT>::mem_resize(const uword new_n_nonzero) | |
4056 { | |
4057 arma_extra_debug_sigprint(); | |
4058 | |
4059 if(n_nonzero != new_n_nonzero) | |
4060 { | |
4061 if(new_n_nonzero == 0) | |
4062 { | |
4063 memory::release(values); | |
4064 memory::release(row_indices); | |
4065 | |
4066 access::rw(values) = memory::acquire_chunked<eT> (1); | |
4067 access::rw(row_indices) = memory::acquire_chunked<uword>(1); | |
4068 | |
4069 access::rw(values[0]) = 0; | |
4070 access::rw(row_indices[0]) = 0; | |
4071 } | |
4072 else | |
4073 { | |
4074 // Figure out the actual amount of memory currently allocated | |
4075 // NOTE: this relies on memory::acquire_chunked() being used for the 'values' and 'row_indices' arrays | |
4076 const uword n_alloc = memory::enlarge_to_mult_of_chunksize(n_nonzero); | |
4077 | |
4078 if(n_alloc < new_n_nonzero) | |
4079 { | |
4080 eT* new_values = memory::acquire_chunked<eT> (new_n_nonzero + 1); | |
4081 uword* new_row_indices = memory::acquire_chunked<uword>(new_n_nonzero + 1); | |
4082 | |
4083 if(n_nonzero > 0) | |
4084 { | |
4085 // Copy old elements. | |
4086 uword copy_len = std::min(n_nonzero, new_n_nonzero); | |
4087 | |
4088 arrayops::copy(new_values, values, copy_len); | |
4089 arrayops::copy(new_row_indices, row_indices, copy_len); | |
4090 } | |
4091 | |
4092 memory::release(values); | |
4093 memory::release(row_indices); | |
4094 | |
4095 access::rw(values) = new_values; | |
4096 access::rw(row_indices) = new_row_indices; | |
4097 } | |
4098 | |
4099 // Set the "fake end" of the matrix by setting the last value and row | |
4100 // index to 0. This helps the iterators work correctly. | |
4101 access::rw(values[new_n_nonzero]) = 0; | |
4102 access::rw(row_indices[new_n_nonzero]) = 0; | |
4103 } | |
4104 | |
4105 access::rw(n_nonzero) = new_n_nonzero; | |
4106 } | |
4107 } | |
4108 | |
4109 | |
4110 | |
4111 // Steal memory from another matrix. | |
4112 template<typename eT> | |
4113 inline | |
4114 void | |
4115 SpMat<eT>::steal_mem(SpMat<eT>& x) | |
4116 { | |
4117 arma_extra_debug_sigprint(); | |
4118 | |
4119 if(this != &x) | |
4120 { | |
4121 // Release all the memory. | |
4122 memory::release(values); | |
4123 memory::release(row_indices); | |
4124 memory::release(col_ptrs); | |
4125 | |
4126 // We'll have to copy everything about the other matrix. | |
4127 const uword x_n_rows = x.n_rows; | |
4128 const uword x_n_cols = x.n_cols; | |
4129 const uword x_n_elem = x.n_elem; | |
4130 const uword x_n_nonzero = x.n_nonzero; | |
4131 | |
4132 access::rw(n_rows) = x_n_rows; | |
4133 access::rw(n_cols) = x_n_cols; | |
4134 access::rw(n_elem) = x_n_elem; | |
4135 access::rw(n_nonzero) = x_n_nonzero; | |
4136 | |
4137 access::rw(values) = x.values; | |
4138 access::rw(row_indices) = x.row_indices; | |
4139 access::rw(col_ptrs) = x.col_ptrs; | |
4140 | |
4141 // Set other matrix to empty. | |
4142 access::rw(x.n_rows) = 0; | |
4143 access::rw(x.n_cols) = 0; | |
4144 access::rw(x.n_elem) = 0; | |
4145 access::rw(x.n_nonzero) = 0; | |
4146 | |
4147 access::rw(x.values) = NULL; | |
4148 access::rw(x.row_indices) = NULL; | |
4149 access::rw(x.col_ptrs) = NULL; | |
4150 } | |
4151 } | |
4152 | |
4153 | |
4154 | |
4155 template<typename eT> | |
4156 template<typename T1, typename Functor> | |
4157 arma_hot | |
4158 inline | |
4159 void | |
4160 SpMat<eT>::init_xform(const SpBase<eT,T1>& A, const Functor& func) | |
4161 { | |
4162 arma_extra_debug_sigprint(); | |
4163 | |
4164 // if possible, avoid doing a copy and instead apply func to the generated elements | |
4165 if(SpProxy<T1>::Q_created_by_proxy == true) | |
4166 { | |
4167 (*this) = A.get_ref(); | |
4168 | |
4169 const uword nnz = n_nonzero; | |
4170 | |
4171 eT* t_values = access::rwp(values); | |
4172 | |
4173 for(uword i=0; i < nnz; ++i) | |
4174 { | |
4175 t_values[i] = func(t_values[i]); | |
4176 } | |
4177 } | |
4178 else | |
4179 { | |
4180 init_xform_mt(A.get_ref(), func); | |
4181 } | |
4182 } | |
4183 | |
4184 | |
4185 | |
4186 template<typename eT> | |
4187 template<typename eT2, typename T1, typename Functor> | |
4188 arma_hot | |
4189 inline | |
4190 void | |
4191 SpMat<eT>::init_xform_mt(const SpBase<eT2,T1>& A, const Functor& func) | |
4192 { | |
4193 arma_extra_debug_sigprint(); | |
4194 | |
4195 const SpProxy<T1> P(A.get_ref()); | |
4196 | |
4197 if( (P.is_alias(*this) == true) || (is_SpMat<typename SpProxy<T1>::stored_type>::value == true) ) | |
4198 { | |
4199 // NOTE: unwrap_spmat will convert a submatrix to a matrix, which in effect takes care of aliasing with submatrices; | |
4200 // NOTE: however, when more delayed ops are implemented, more elaborate handling of aliasing will be necessary | |
4201 const unwrap_spmat<typename SpProxy<T1>::stored_type> tmp(P.Q); | |
4202 | |
4203 const SpMat<eT2>& x = tmp.M; | |
4204 | |
4205 if(void_ptr(this) != void_ptr(&x)) | |
4206 { | |
4207 init(x.n_rows, x.n_cols); | |
4208 | |
4209 // values and row_indices may not be null. | |
4210 if(values != NULL) | |
4211 { | |
4212 memory::release(values); | |
4213 memory::release(row_indices); | |
4214 } | |
4215 | |
4216 access::rw(values) = memory::acquire_chunked<eT> (x.n_nonzero + 1); | |
4217 access::rw(row_indices) = memory::acquire_chunked<uword>(x.n_nonzero + 1); | |
4218 | |
4219 arrayops::copy(access::rwp(row_indices), x.row_indices, x.n_nonzero + 1); | |
4220 arrayops::copy(access::rwp(col_ptrs), x.col_ptrs, x.n_cols + 1); | |
4221 | |
4222 access::rw(n_nonzero) = x.n_nonzero; | |
4223 } | |
4224 | |
4225 | |
4226 // initialise the elements array with a transformed version of the elements from x | |
4227 | |
4228 const uword nnz = n_nonzero; | |
4229 | |
4230 const eT2* x_values = x.values; | |
4231 eT* t_values = access::rwp(values); | |
4232 | |
4233 for(uword i=0; i < nnz; ++i) | |
4234 { | |
4235 t_values[i] = func(x_values[i]); // NOTE: func() must produce a value of type eT (ie. act as a convertor between eT2 and eT) | |
4236 } | |
4237 } | |
4238 else | |
4239 { | |
4240 init(P.get_n_rows(), P.get_n_cols()); | |
4241 | |
4242 mem_resize(P.get_n_nonzero()); | |
4243 | |
4244 typename SpProxy<T1>::const_iterator_type it = P.begin(); | |
4245 | |
4246 while(it != P.end()) | |
4247 { | |
4248 access::rw(row_indices[it.pos()]) = it.row(); | |
4249 access::rw(values[it.pos()]) = func(*it); // NOTE: func() must produce a value of type eT (ie. act as a convertor between eT2 and eT) | |
4250 ++access::rw(col_ptrs[it.col() + 1]); | |
4251 ++it; | |
4252 } | |
4253 | |
4254 // Now sum column pointers. | |
4255 for(uword c = 1; c <= n_cols; ++c) | |
4256 { | |
4257 access::rw(col_ptrs[c]) += col_ptrs[c - 1]; | |
4258 } | |
4259 } | |
4260 } | |
4261 | |
4262 | |
4263 | |
4264 template<typename eT> | |
4265 inline | |
4266 typename SpMat<eT>::iterator | |
4267 SpMat<eT>::begin() | |
4268 { | |
4269 return iterator(*this); | |
4270 } | |
4271 | |
4272 | |
4273 | |
4274 template<typename eT> | |
4275 inline | |
4276 typename SpMat<eT>::const_iterator | |
4277 SpMat<eT>::begin() const | |
4278 { | |
4279 return const_iterator(*this); | |
4280 } | |
4281 | |
4282 | |
4283 | |
4284 template<typename eT> | |
4285 inline | |
4286 typename SpMat<eT>::iterator | |
4287 SpMat<eT>::end() | |
4288 { | |
4289 return iterator(*this, 0, n_cols, n_nonzero); | |
4290 } | |
4291 | |
4292 | |
4293 | |
4294 template<typename eT> | |
4295 inline | |
4296 typename SpMat<eT>::const_iterator | |
4297 SpMat<eT>::end() const | |
4298 { | |
4299 return const_iterator(*this, 0, n_cols, n_nonzero); | |
4300 } | |
4301 | |
4302 | |
4303 | |
4304 template<typename eT> | |
4305 inline | |
4306 typename SpMat<eT>::iterator | |
4307 SpMat<eT>::begin_col(const uword col_num) | |
4308 { | |
4309 return iterator(*this, 0, col_num); | |
4310 } | |
4311 | |
4312 | |
4313 | |
4314 template<typename eT> | |
4315 inline | |
4316 typename SpMat<eT>::const_iterator | |
4317 SpMat<eT>::begin_col(const uword col_num) const | |
4318 { | |
4319 return const_iterator(*this, 0, col_num); | |
4320 } | |
4321 | |
4322 | |
4323 | |
4324 template<typename eT> | |
4325 inline | |
4326 typename SpMat<eT>::iterator | |
4327 SpMat<eT>::end_col(const uword col_num) | |
4328 { | |
4329 return iterator(*this, 0, col_num + 1); | |
4330 } | |
4331 | |
4332 | |
4333 | |
4334 template<typename eT> | |
4335 inline | |
4336 typename SpMat<eT>::const_iterator | |
4337 SpMat<eT>::end_col(const uword col_num) const | |
4338 { | |
4339 return const_iterator(*this, 0, col_num + 1); | |
4340 } | |
4341 | |
4342 | |
4343 | |
4344 template<typename eT> | |
4345 inline | |
4346 typename SpMat<eT>::row_iterator | |
4347 SpMat<eT>::begin_row(const uword row_num) | |
4348 { | |
4349 return row_iterator(*this, row_num, 0); | |
4350 } | |
4351 | |
4352 | |
4353 | |
4354 template<typename eT> | |
4355 inline | |
4356 typename SpMat<eT>::const_row_iterator | |
4357 SpMat<eT>::begin_row(const uword row_num) const | |
4358 { | |
4359 return const_row_iterator(*this, row_num, 0); | |
4360 } | |
4361 | |
4362 | |
4363 | |
4364 template<typename eT> | |
4365 inline | |
4366 typename SpMat<eT>::row_iterator | |
4367 SpMat<eT>::end_row() | |
4368 { | |
4369 return row_iterator(*this, n_nonzero); | |
4370 } | |
4371 | |
4372 | |
4373 | |
4374 template<typename eT> | |
4375 inline | |
4376 typename SpMat<eT>::const_row_iterator | |
4377 SpMat<eT>::end_row() const | |
4378 { | |
4379 return const_row_iterator(*this, n_nonzero); | |
4380 } | |
4381 | |
4382 | |
4383 | |
4384 template<typename eT> | |
4385 inline | |
4386 typename SpMat<eT>::row_iterator | |
4387 SpMat<eT>::end_row(const uword row_num) | |
4388 { | |
4389 return row_iterator(*this, row_num + 1, 0); | |
4390 } | |
4391 | |
4392 | |
4393 | |
4394 template<typename eT> | |
4395 inline | |
4396 typename SpMat<eT>::const_row_iterator | |
4397 SpMat<eT>::end_row(const uword row_num) const | |
4398 { | |
4399 return const_row_iterator(*this, row_num + 1, 0); | |
4400 } | |
4401 | |
4402 | |
4403 | |
4404 template<typename eT> | |
4405 inline | |
4406 void | |
4407 SpMat<eT>::clear() | |
4408 { | |
4409 if (values) | |
4410 { | |
4411 memory::release(values); | |
4412 memory::release(row_indices); | |
4413 | |
4414 access::rw(values) = memory::acquire_chunked<eT> (1); | |
4415 access::rw(row_indices) = memory::acquire_chunked<uword>(1); | |
4416 | |
4417 access::rw(values[0]) = 0; | |
4418 access::rw(row_indices[0]) = 0; | |
4419 } | |
4420 | |
4421 memory::release(col_ptrs); | |
4422 | |
4423 access::rw(col_ptrs) = memory::acquire<uword>(n_cols + 2); | |
4424 access::rw(col_ptrs[n_cols + 1]) = std::numeric_limits<uword>::max(); | |
4425 | |
4426 arrayops::inplace_set(col_ptrs, eT(0), n_cols + 1); | |
4427 | |
4428 access::rw(n_nonzero) = 0; | |
4429 } | |
4430 | |
4431 | |
4432 | |
4433 template<typename eT> | |
4434 inline | |
4435 bool | |
4436 SpMat<eT>::empty() const | |
4437 { | |
4438 return (n_elem == 0); | |
4439 } | |
4440 | |
4441 | |
4442 | |
4443 template<typename eT> | |
4444 inline | |
4445 uword | |
4446 SpMat<eT>::size() const | |
4447 { | |
4448 return n_elem; | |
4449 } | |
4450 | |
4451 | |
4452 | |
4453 template<typename eT> | |
4454 inline | |
4455 arma_hot | |
4456 arma_warn_unused | |
4457 SpValProxy<SpMat<eT> > | |
4458 SpMat<eT>::get_value(const uword i) | |
4459 { | |
4460 // First convert to the actual location. | |
4461 uword lcol = i / n_rows; // Integer division. | |
4462 uword lrow = i % n_rows; | |
4463 | |
4464 return get_value(lrow, lcol); | |
4465 } | |
4466 | |
4467 | |
4468 | |
4469 template<typename eT> | |
4470 inline | |
4471 arma_hot | |
4472 arma_warn_unused | |
4473 eT | |
4474 SpMat<eT>::get_value(const uword i) const | |
4475 { | |
4476 // First convert to the actual location. | |
4477 uword lcol = i / n_rows; // Integer division. | |
4478 uword lrow = i % n_rows; | |
4479 | |
4480 return get_value(lrow, lcol); | |
4481 } | |
4482 | |
4483 | |
4484 | |
4485 template<typename eT> | |
4486 inline | |
4487 arma_hot | |
4488 arma_warn_unused | |
4489 SpValProxy<SpMat<eT> > | |
4490 SpMat<eT>::get_value(const uword in_row, const uword in_col) | |
4491 { | |
4492 const uword colptr = col_ptrs[in_col]; | |
4493 const uword next_colptr = col_ptrs[in_col + 1]; | |
4494 | |
4495 // Step through the row indices to see if our element exists. | |
4496 for (uword i = colptr; i < next_colptr; ++i) | |
4497 { | |
4498 const uword row_index = row_indices[i]; | |
4499 | |
4500 // First check that we have not stepped past it. | |
4501 if (in_row < row_index) // If we have, then it doesn't exist: return 0. | |
4502 { | |
4503 return SpValProxy<SpMat<eT> >(in_row, in_col, *this); // Proxy for a zero value. | |
4504 } | |
4505 | |
4506 // Now check if we are at the correct place. | |
4507 if (in_row == row_index) // If we are, return a reference to the value. | |
4508 { | |
4509 return SpValProxy<SpMat<eT> >(in_row, in_col, *this, &access::rw(values[i])); | |
4510 } | |
4511 | |
4512 } | |
4513 | |
4514 // We did not find it, so it does not exist: return 0. | |
4515 return SpValProxy<SpMat<eT> >(in_row, in_col, *this); | |
4516 } | |
4517 | |
4518 | |
4519 | |
4520 template<typename eT> | |
4521 inline | |
4522 arma_hot | |
4523 arma_warn_unused | |
4524 eT | |
4525 SpMat<eT>::get_value(const uword in_row, const uword in_col) const | |
4526 { | |
4527 const uword colptr = col_ptrs[in_col]; | |
4528 const uword next_colptr = col_ptrs[in_col + 1]; | |
4529 | |
4530 // Step through the row indices to see if our element exists. | |
4531 for (uword i = colptr; i < next_colptr; ++i) | |
4532 { | |
4533 const uword row_index = row_indices[i]; | |
4534 | |
4535 // First check that we have not stepped past it. | |
4536 if (in_row < row_index) // If we have, then it doesn't exist: return 0. | |
4537 { | |
4538 return eT(0); | |
4539 } | |
4540 | |
4541 // Now check if we are at the correct place. | |
4542 if (in_row == row_index) // If we are, return the value. | |
4543 { | |
4544 return values[i]; | |
4545 } | |
4546 } | |
4547 | |
4548 // We did not find it, so it does not exist: return 0. | |
4549 return eT(0); | |
4550 } | |
4551 | |
4552 | |
4553 | |
4554 /** | |
4555 * Given the index representing which of the nonzero values this is, return its | |
4556 * actual location, either in row/col or just the index. | |
4557 */ | |
4558 template<typename eT> | |
4559 arma_hot | |
4560 arma_inline | |
4561 arma_warn_unused | |
4562 uword | |
4563 SpMat<eT>::get_position(const uword i) const | |
4564 { | |
4565 uword lrow, lcol; | |
4566 | |
4567 get_position(i, lrow, lcol); | |
4568 | |
4569 // Assemble the row/col into the element's location in the matrix. | |
4570 return (lrow + n_rows * lcol); | |
4571 } | |
4572 | |
4573 | |
4574 | |
4575 template<typename eT> | |
4576 arma_hot | |
4577 arma_inline | |
4578 void | |
4579 SpMat<eT>::get_position(const uword i, uword& row_of_i, uword& col_of_i) const | |
4580 { | |
4581 arma_debug_check((i >= n_nonzero), "SpMat::get_position(): index out of bounds"); | |
4582 | |
4583 col_of_i = 0; | |
4584 while (col_ptrs[col_of_i + 1] <= i) | |
4585 { | |
4586 col_of_i++; | |
4587 } | |
4588 | |
4589 row_of_i = row_indices[i]; | |
4590 | |
4591 return; | |
4592 } | |
4593 | |
4594 | |
4595 | |
4596 /** | |
4597 * Add an element at the given position, and return a reference to it. The | |
4598 * element will be set to 0 (unless otherwise specified). If the element | |
4599 * already exists, its value will be overwritten. | |
4600 * | |
4601 * @param in_row Row of new element. | |
4602 * @param in_col Column of new element. | |
4603 * @param in_val Value to set new element to (default 0.0). | |
4604 */ | |
4605 template<typename eT> | |
4606 inline | |
4607 arma_hot | |
4608 arma_warn_unused | |
4609 eT& | |
4610 SpMat<eT>::add_element(const uword in_row, const uword in_col, const eT val) | |
4611 { | |
4612 arma_extra_debug_sigprint(); | |
4613 | |
4614 // We will assume the new element does not exist and begin the search for | |
4615 // where to insert it. If we find that it already exists, we will then | |
4616 // overwrite it. | |
4617 uword colptr = col_ptrs[in_col ]; | |
4618 uword next_colptr = col_ptrs[in_col + 1]; | |
4619 | |
4620 uword pos = colptr; // The position in the matrix of this value. | |
4621 | |
4622 if (colptr != next_colptr) | |
4623 { | |
4624 // There are other elements in this column, so we must find where this | |
4625 // element will fit as compared to those. | |
4626 while (pos < next_colptr && in_row > row_indices[pos]) | |
4627 { | |
4628 pos++; | |
4629 } | |
4630 | |
4631 // We aren't inserting into the last position, so it is still possible | |
4632 // that the element may exist. | |
4633 if (pos != next_colptr && row_indices[pos] == in_row) | |
4634 { | |
4635 // It already exists. Then, just overwrite it. | |
4636 access::rw(values[pos]) = val; | |
4637 | |
4638 return access::rw(values[pos]); | |
4639 } | |
4640 } | |
4641 | |
4642 | |
4643 // | |
4644 // Element doesn't exist, so we have to insert it | |
4645 // | |
4646 | |
4647 // We have to update the rest of the column pointers. | |
4648 for (uword i = in_col + 1; i < n_cols + 1; i++) | |
4649 { | |
4650 access::rw(col_ptrs[i])++; // We are only inserting one new element. | |
4651 } | |
4652 | |
4653 | |
4654 // Figure out the actual amount of memory currently allocated | |
4655 // NOTE: this relies on memory::acquire_chunked() being used for the 'values' and 'row_indices' arrays | |
4656 const uword n_alloc = memory::enlarge_to_mult_of_chunksize(n_nonzero + 1); | |
4657 | |
4658 // If possible, avoid time-consuming memory allocation | |
4659 if(n_alloc > (n_nonzero + 1)) | |
4660 { | |
4661 arrayops::copy_backwards(access::rwp(values) + pos + 1, values + pos, (n_nonzero - pos) + 1); | |
4662 arrayops::copy_backwards(access::rwp(row_indices) + pos + 1, row_indices + pos, (n_nonzero - pos) + 1); | |
4663 | |
4664 // Insert the new element. | |
4665 access::rw(values[pos]) = val; | |
4666 access::rw(row_indices[pos]) = in_row; | |
4667 | |
4668 access::rw(n_nonzero)++; | |
4669 } | |
4670 else | |
4671 { | |
4672 const uword old_n_nonzero = n_nonzero; | |
4673 | |
4674 access::rw(n_nonzero)++; // Add to count of nonzero elements. | |
4675 | |
4676 // Allocate larger memory. | |
4677 eT* new_values = memory::acquire_chunked<eT> (n_nonzero + 1); | |
4678 uword* new_row_indices = memory::acquire_chunked<uword>(n_nonzero + 1); | |
4679 | |
4680 // Copy things over, before the new element. | |
4681 if (pos > 0) | |
4682 { | |
4683 arrayops::copy(new_values, values, pos); | |
4684 arrayops::copy(new_row_indices, row_indices, pos); | |
4685 } | |
4686 | |
4687 // Insert the new element. | |
4688 new_values[pos] = val; | |
4689 new_row_indices[pos] = in_row; | |
4690 | |
4691 // Copy the rest of things over (including the extra element at the end). | |
4692 arrayops::copy(new_values + pos + 1, values + pos, (old_n_nonzero - pos) + 1); | |
4693 arrayops::copy(new_row_indices + pos + 1, row_indices + pos, (old_n_nonzero - pos) + 1); | |
4694 | |
4695 // Assign new pointers. | |
4696 memory::release(values); | |
4697 memory::release(row_indices); | |
4698 | |
4699 access::rw(values) = new_values; | |
4700 access::rw(row_indices) = new_row_indices; | |
4701 } | |
4702 | |
4703 return access::rw(values[pos]); | |
4704 } | |
4705 | |
4706 | |
4707 | |
4708 /** | |
4709 * Delete an element at the given position. | |
4710 * | |
4711 * @param in_row Row of element to be deleted. | |
4712 * @param in_col Column of element to be deleted. | |
4713 */ | |
4714 template<typename eT> | |
4715 inline | |
4716 arma_hot | |
4717 void | |
4718 SpMat<eT>::delete_element(const uword in_row, const uword in_col) | |
4719 { | |
4720 arma_extra_debug_sigprint(); | |
4721 | |
4722 // We assume the element exists (although... it may not) and look for its | |
4723 // exact position. If it doesn't exist... well, we don't need to do anything. | |
4724 uword colptr = col_ptrs[in_col]; | |
4725 uword next_colptr = col_ptrs[in_col + 1]; | |
4726 | |
4727 if (colptr != next_colptr) | |
4728 { | |
4729 // There's at least one element in this column. | |
4730 // Let's see if we are one of them. | |
4731 for (uword pos = colptr; pos < next_colptr; pos++) | |
4732 { | |
4733 if (in_row == row_indices[pos]) | |
4734 { | |
4735 const uword old_n_nonzero = n_nonzero; | |
4736 | |
4737 --access::rw(n_nonzero); // Remove one from the count of nonzero elements. | |
4738 | |
4739 // Found it. Now remove it. | |
4740 | |
4741 // Figure out the actual amount of memory currently allocated and the actual amount that will be required | |
4742 // NOTE: this relies on memory::acquire_chunked() being used for the 'values' and 'row_indices' arrays | |
4743 | |
4744 const uword n_alloc = memory::enlarge_to_mult_of_chunksize(old_n_nonzero + 1); | |
4745 const uword n_alloc_mod = memory::enlarge_to_mult_of_chunksize(n_nonzero + 1); | |
4746 | |
4747 // If possible, avoid time-consuming memory allocation | |
4748 if(n_alloc_mod == n_alloc) | |
4749 { | |
4750 if (pos < n_nonzero) // remember, we decremented n_nonzero | |
4751 { | |
4752 arrayops::copy_forwards(access::rwp(values) + pos, values + pos + 1, (n_nonzero - pos) + 1); | |
4753 arrayops::copy_forwards(access::rwp(row_indices) + pos, row_indices + pos + 1, (n_nonzero - pos) + 1); | |
4754 } | |
4755 } | |
4756 else | |
4757 { | |
4758 // Make new arrays. | |
4759 eT* new_values = memory::acquire_chunked<eT> (n_nonzero + 1); | |
4760 uword* new_row_indices = memory::acquire_chunked<uword>(n_nonzero + 1); | |
4761 | |
4762 if (pos > 0) | |
4763 { | |
4764 arrayops::copy(new_values, values, pos); | |
4765 arrayops::copy(new_row_indices, row_indices, pos); | |
4766 } | |
4767 | |
4768 arrayops::copy(new_values + pos, values + pos + 1, (n_nonzero - pos) + 1); | |
4769 arrayops::copy(new_row_indices + pos, row_indices + pos + 1, (n_nonzero - pos) + 1); | |
4770 | |
4771 memory::release(values); | |
4772 memory::release(row_indices); | |
4773 | |
4774 access::rw(values) = new_values; | |
4775 access::rw(row_indices) = new_row_indices; | |
4776 } | |
4777 | |
4778 // And lastly, update all the column pointers (decrement by one). | |
4779 for (uword i = in_col + 1; i < n_cols + 1; i++) | |
4780 { | |
4781 --access::rw(col_ptrs[i]); // We only removed one element. | |
4782 } | |
4783 | |
4784 return; // There is nothing left to do. | |
4785 } | |
4786 } | |
4787 } | |
4788 | |
4789 return; // The element does not exist, so there's nothing for us to do. | |
4790 } | |
4791 | |
4792 | |
4793 | |
4794 #ifdef ARMA_EXTRA_SPMAT_MEAT | |
4795 #include ARMA_INCFILE_WRAP(ARMA_EXTRA_SPMAT_MEAT) | |
4796 #endif | |
4797 | |
4798 | |
4799 | |
4800 //! @} |