comparison armadillo-3.900.4/include/armadillo_bits/SpMat_meat.hpp @ 49:1ec0e2823891

Switch to using subrepo copies of qm-dsp, nnls-chroma, vamp-plugin-sdk; update Armadillo version; assume build without external BLAS/LAPACK
author Chris Cannam
date Thu, 13 Jun 2013 10:25:24 +0100
parents
children
comparison
equal deleted inserted replaced
48:69251e11a913 49:1ec0e2823891
1 // Copyright (C) 2011-2013 Ryan Curtin
2 // Copyright (C) 2012-2013 Conrad Sanderson
3 // Copyright (C) 2011 Matthew Amidon
4 //
5 // This Source Code Form is subject to the terms of the Mozilla Public
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
8
9 //! \addtogroup SpMat
10 //! @{
11
12 /**
13 * Initialize a sparse matrix with size 0x0 (empty).
14 */
15 template<typename eT>
16 inline
17 SpMat<eT>::SpMat()
18 : n_rows(0)
19 , n_cols(0)
20 , n_elem(0)
21 , n_nonzero(0)
22 , vec_state(0)
23 , values(memory::acquire_chunked<eT>(1))
24 , row_indices(memory::acquire_chunked<uword>(1))
25 , col_ptrs(memory::acquire<uword>(2))
26 {
27 arma_extra_debug_sigprint_this(this);
28
29 access::rw(values[0]) = 0;
30 access::rw(row_indices[0]) = 0;
31
32 access::rw(col_ptrs[0]) = 0; // No elements.
33 access::rw(col_ptrs[1]) = std::numeric_limits<uword>::max();
34 }
35
36
37
38 /**
39 * Clean up the memory of a sparse matrix and destruct it.
40 */
41 template<typename eT>
42 inline
43 SpMat<eT>::~SpMat()
44 {
45 arma_extra_debug_sigprint_this(this);
46
47 // If necessary, release the memory.
48 if (values)
49 {
50 // values being non-NULL implies row_indices is non-NULL.
51 memory::release(access::rw(values));
52 memory::release(access::rw(row_indices));
53 }
54
55 // Column pointers always must be deleted.
56 memory::release(access::rw(col_ptrs));
57 }
58
59
60
61 /**
62 * Constructor with size given.
63 */
64 template<typename eT>
65 inline
66 SpMat<eT>::SpMat(const uword in_rows, const uword in_cols)
67 : n_rows(0)
68 , n_cols(0)
69 , n_elem(0)
70 , n_nonzero(0)
71 , vec_state(0)
72 , values(NULL)
73 , row_indices(NULL)
74 , col_ptrs(NULL)
75 {
76 arma_extra_debug_sigprint_this(this);
77
78 init(in_rows, in_cols);
79 }
80
81
82
83 /**
84 * Assemble from text.
85 */
86 template<typename eT>
87 inline
88 SpMat<eT>::SpMat(const char* text)
89 : n_rows(0)
90 , n_cols(0)
91 , n_elem(0)
92 , n_nonzero(0)
93 , vec_state(0)
94 , values(NULL)
95 , row_indices(NULL)
96 , col_ptrs(NULL)
97 {
98 arma_extra_debug_sigprint_this(this);
99
100 init(std::string(text));
101 }
102
103
104
105 template<typename eT>
106 inline
107 const SpMat<eT>&
108 SpMat<eT>::operator=(const char* text)
109 {
110 arma_extra_debug_sigprint();
111
112 init(std::string(text));
113 }
114
115
116
117 template<typename eT>
118 inline
119 SpMat<eT>::SpMat(const std::string& text)
120 : n_rows(0)
121 , n_cols(0)
122 , n_elem(0)
123 , n_nonzero(0)
124 , vec_state(0)
125 , values(NULL)
126 , row_indices(NULL)
127 , col_ptrs(NULL)
128 {
129 arma_extra_debug_sigprint();
130
131 init(text);
132 }
133
134
135
136 template<typename eT>
137 inline
138 const SpMat<eT>&
139 SpMat<eT>::operator=(const std::string& text)
140 {
141 arma_extra_debug_sigprint();
142
143 init(text);
144 }
145
146
147
148 template<typename eT>
149 inline
150 SpMat<eT>::SpMat(const SpMat<eT>& x)
151 : n_rows(0)
152 , n_cols(0)
153 , n_elem(0)
154 , n_nonzero(0)
155 , vec_state(0)
156 , values(NULL)
157 , row_indices(NULL)
158 , col_ptrs(NULL)
159 {
160 arma_extra_debug_sigprint_this(this);
161
162 init(x);
163 }
164
165
166
167 //! Insert a large number of values at once.
168 //! locations.row[0] should be row indices, locations.row[1] should be column indices,
169 //! and values should be the corresponding values.
170 //! If sort_locations is false, then it is assumed that the locations and values
171 //! are already sorted in column-major ordering.
172 template<typename eT>
173 template<typename T1, typename T2>
174 inline
175 SpMat<eT>::SpMat(const Base<uword,T1>& locations_expr, const Base<eT,T2>& vals_expr, const bool sort_locations)
176 : n_rows(0)
177 , n_cols(0)
178 , n_elem(0)
179 , n_nonzero(0)
180 , vec_state(0)
181 , values(NULL)
182 , row_indices(NULL)
183 , col_ptrs(NULL)
184 {
185 arma_extra_debug_sigprint_this(this);
186
187 const unwrap<T1> locs_tmp( locations_expr.get_ref() );
188 const Mat<uword>& locs = locs_tmp.M;
189
190 const unwrap<T2> vals_tmp( vals_expr.get_ref() );
191 const Mat<eT>& vals = vals_tmp.M;
192
193 arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object is not a vector" );
194
195 arma_debug_check((locs.n_cols != vals.n_elem), "SpMat::SpMat(): number of locations is different than number of values");
196
197 // If there are no elements in the list, max() will fail.
198 if (locs.n_cols == 0)
199 {
200 init(0, 0);
201 return;
202 }
203
204 arma_debug_check((locs.n_rows != 2), "SpMat::SpMat(): locations matrix must have two rows");
205
206 // Automatically determine size (and check if it's sorted).
207 uvec bounds = arma::max(locs, 1);
208 init(bounds[0] + 1, bounds[1] + 1);
209
210 // Resize to correct number of elements.
211 mem_resize(vals.n_elem);
212
213 // Reset column pointers to zero.
214 arrayops::inplace_set(access::rwp(col_ptrs), uword(0), n_cols + 1);
215
216 bool actually_sorted = true;
217 if(sort_locations == true)
218 {
219 // sort_index() uses std::sort() which may use quicksort... so we better
220 // make sure it's not already sorted before taking an O(N^2) sort penalty.
221 for (uword i = 1; i < locs.n_cols; ++i)
222 {
223 if ((locs.at(1, i) < locs.at(1, i - 1)) || (locs.at(1, i) == locs.at(1, i - 1) && locs.at(0, i) <= locs.at(0, i - 1)))
224 {
225 actually_sorted = false;
226 break;
227 }
228 }
229
230 if(actually_sorted == false)
231 {
232 // This may not be the fastest possible implementation but it maximizes code reuse.
233 Col<uword> abslocs(locs.n_cols);
234
235 for (uword i = 0; i < locs.n_cols; ++i)
236 {
237 abslocs[i] = locs.at(1, i) * n_rows + locs.at(0, i);
238 }
239
240 // Now we will sort with sort_index().
241 uvec sorted_indices = sort_index(abslocs); // Ascending sort.
242
243 // Now we add the elements in this sorted order.
244 for (uword i = 0; i < sorted_indices.n_elem; ++i)
245 {
246 arma_debug_check((locs.at(0, sorted_indices[i]) >= n_rows), "SpMat::SpMat(): invalid row index");
247 arma_debug_check((locs.at(1, sorted_indices[i]) >= n_cols), "SpMat::SpMat(): invalid column index");
248
249 access::rw(values[i]) = vals[sorted_indices[i]];
250 access::rw(row_indices[i]) = locs.at(0, sorted_indices[i]);
251
252 access::rw(col_ptrs[locs.at(1, sorted_indices[i]) + 1])++;
253 }
254 }
255 }
256
257 if( (sort_locations == false) || (actually_sorted == true) )
258 {
259 // Now set the values and row indices correctly.
260 // Increment the column pointers in each column (so they are column "counts").
261 for (uword i = 0; i < vals.n_elem; ++i)
262 {
263 arma_debug_check((locs.at(0, i) >= n_rows), "SpMat::SpMat(): invalid row index");
264 arma_debug_check((locs.at(1, i) >= n_cols), "SpMat::SpMat(): invalid column index");
265
266 // Check ordering in debug mode.
267 if(i > 0)
268 {
269 arma_debug_check
270 (
271 ( (locs.at(1, i) < locs.at(1, i - 1)) || (locs.at(1, i) == locs.at(1, i - 1) && locs.at(0, i) < locs.at(0, i - 1)) ),
272 "SpMat::SpMat(): out of order points; either pass sort_locations = true, or sort points in column-major ordering"
273 );
274 arma_debug_check((locs.at(1, i) == locs.at(1, i - 1) && locs.at(0, i) == locs.at(0, i - 1)), "SpMat::SpMat(): two identical point locations in list");
275 }
276
277 access::rw(values[i]) = vals[i];
278 access::rw(row_indices[i]) = locs.at(0, i);
279
280 access::rw(col_ptrs[locs.at(1, i) + 1])++;
281 }
282 }
283
284 // Now fix the column pointers.
285 for (uword i = 0; i <= n_cols; ++i)
286 {
287 access::rw(col_ptrs[i + 1]) += col_ptrs[i];
288 }
289 }
290
291
292
293 //! Insert a large number of values at once.
294 //! locations.row[0] should be row indices, locations.row[1] should be column indices,
295 //! and values should be the corresponding values.
296 //! If sort_locations is false, then it is assumed that the locations and values
297 //! are already sorted in column-major ordering.
298 //! In this constructor the size is explicitly given.
299 template<typename eT>
300 template<typename T1, typename T2>
301 inline
302 SpMat<eT>::SpMat(const Base<uword,T1>& locations_expr, const Base<eT,T2>& vals_expr, const uword in_n_rows, const uword in_n_cols, const bool sort_locations)
303 : n_rows(0)
304 , n_cols(0)
305 , n_elem(0)
306 , n_nonzero(0)
307 , vec_state(0)
308 , values(NULL)
309 , row_indices(NULL)
310 , col_ptrs(NULL)
311 {
312 arma_extra_debug_sigprint_this(this);
313
314 init(in_n_rows, in_n_cols);
315
316 const unwrap<T1> locs_tmp( locations_expr.get_ref() );
317 const Mat<uword>& locs = locs_tmp.M;
318
319 const unwrap<T2> vals_tmp( vals_expr.get_ref() );
320 const Mat<eT>& vals = vals_tmp.M;
321
322 arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object is not a vector" );
323
324 arma_debug_check((locs.n_rows != 2), "SpMat::SpMat(): locations matrix must have two rows");
325
326 arma_debug_check((locs.n_cols != vals.n_elem), "SpMat::SpMat(): number of locations is different than number of values");
327
328 // Resize to correct number of elements.
329 mem_resize(vals.n_elem);
330
331 // Reset column pointers to zero.
332 arrayops::inplace_set(access::rwp(col_ptrs), uword(0), n_cols + 1);
333
334 bool actually_sorted = true;
335 if(sort_locations == true)
336 {
337 // sort_index() uses std::sort() which may use quicksort... so we better
338 // make sure it's not already sorted before taking an O(N^2) sort penalty.
339 for (uword i = 1; i < locs.n_cols; ++i)
340 {
341 if ((locs.at(1, i) < locs.at(1, i - 1)) || (locs.at(1, i) == locs.at(1, i - 1) && locs.at(0, i) <= locs.at(0, i - 1)))
342 {
343 actually_sorted = false;
344 break;
345 }
346 }
347
348 if(actually_sorted == false)
349 {
350 // This may not be the fastest possible implementation but it maximizes code reuse.
351 Col<uword> abslocs(locs.n_cols);
352
353 for (uword i = 0; i < locs.n_cols; ++i)
354 {
355 abslocs[i] = locs.at(1, i) * n_rows + locs.at(0, i);
356 }
357
358 // Now we will sort with sort_index().
359 uvec sorted_indices = sort_index(abslocs); // Ascending sort.
360
361 // Now we add the elements in this sorted order.
362 for (uword i = 0; i < sorted_indices.n_elem; ++i)
363 {
364 arma_debug_check((locs.at(0, sorted_indices[i]) >= n_rows), "SpMat::SpMat(): invalid row index");
365 arma_debug_check((locs.at(1, sorted_indices[i]) >= n_cols), "SpMat::SpMat(): invalid column index");
366
367 access::rw(values[i]) = vals[sorted_indices[i]];
368 access::rw(row_indices[i]) = locs.at(0, sorted_indices[i]);
369
370 access::rw(col_ptrs[locs.at(1, sorted_indices[i]) + 1])++;
371 }
372 }
373 }
374
375 if( (sort_locations == false) || (actually_sorted == true) )
376 {
377 // Now set the values and row indices correctly.
378 // Increment the column pointers in each column (so they are column "counts").
379 for (uword i = 0; i < vals.n_elem; ++i)
380 {
381 arma_debug_check((locs.at(0, i) >= n_rows), "SpMat::SpMat(): invalid row index");
382 arma_debug_check((locs.at(1, i) >= n_cols), "SpMat::SpMat(): invalid column index");
383
384 // Check ordering in debug mode.
385 if(i > 0)
386 {
387 arma_debug_check
388 (
389 ( (locs.at(1, i) < locs.at(1, i - 1)) || (locs.at(1, i) == locs.at(1, i - 1) && locs.at(0, i) < locs.at(0, i - 1)) ),
390 "SpMat::SpMat(): out of order points; either pass sort_locations = true or sort points in column-major ordering"
391 );
392 arma_debug_check((locs.at(1, i) == locs.at(1, i - 1) && locs.at(0, i) == locs.at(0, i - 1)), "SpMat::SpMat(): two identical point locations in list");
393 }
394
395 access::rw(values[i]) = vals[i];
396 access::rw(row_indices[i]) = locs.at(0, i);
397
398 access::rw(col_ptrs[locs.at(1, i) + 1])++;
399 }
400 }
401
402 // Now fix the column pointers.
403 for (uword i = 0; i <= n_cols; ++i)
404 {
405 access::rw(col_ptrs[i + 1]) += col_ptrs[i];
406 }
407 }
408
409
410
411 /**
412 * Simple operators with plain values. These operate on every value in the
413 * matrix, so a sparse matrix += 1 will turn all those zeroes into ones. Be
414 * careful and make sure that's what you really want!
415 */
416 template<typename eT>
417 inline
418 const SpMat<eT>&
419 SpMat<eT>::operator=(const eT val)
420 {
421 arma_extra_debug_sigprint();
422
423 // Resize to 1x1 then set that to the right value.
424 init(1, 1); // Sets col_ptrs to 0.
425 mem_resize(1); // One element.
426
427 // Manually set element.
428 access::rw(values[0]) = val;
429 access::rw(row_indices[0]) = 0;
430 access::rw(col_ptrs[1]) = 1;
431
432 return *this;
433 }
434
435
436
437 template<typename eT>
438 inline
439 const SpMat<eT>&
440 SpMat<eT>::operator*=(const eT val)
441 {
442 arma_extra_debug_sigprint();
443
444 if(val == eT(0))
445 {
446 // Everything will be zero.
447 init(n_rows, n_cols);
448 return *this;
449 }
450
451 arrayops::inplace_mul( access::rwp(values), val, n_nonzero );
452
453 return *this;
454 }
455
456
457
458 template<typename eT>
459 inline
460 const SpMat<eT>&
461 SpMat<eT>::operator/=(const eT val)
462 {
463 arma_extra_debug_sigprint();
464
465 arma_debug_check( (val == eT(0)), "element-wise division: division by zero" );
466
467 arrayops::inplace_div( access::rwp(values), val, n_nonzero );
468
469 return *this;
470 }
471
472
473
474 template<typename eT>
475 inline
476 const SpMat<eT>&
477 SpMat<eT>::operator=(const SpMat<eT>& x)
478 {
479 arma_extra_debug_sigprint();
480
481 init(x);
482
483 return *this;
484 }
485
486
487
488 template<typename eT>
489 inline
490 const SpMat<eT>&
491 SpMat<eT>::operator+=(const SpMat<eT>& x)
492 {
493 arma_extra_debug_sigprint();
494
495 arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "addition");
496
497 // Iterate over nonzero values of other matrix.
498 for (const_iterator it = x.begin(); it != x.end(); it++)
499 {
500 get_value(it.row(), it.col()) += *it;
501 }
502
503 return *this;
504 }
505
506
507
508 template<typename eT>
509 inline
510 const SpMat<eT>&
511 SpMat<eT>::operator-=(const SpMat<eT>& x)
512 {
513 arma_extra_debug_sigprint();
514
515 arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "subtraction");
516
517 // Iterate over nonzero values of other matrix.
518 for (const_iterator it = x.begin(); it != x.end(); it++)
519 {
520 get_value(it.row(), it.col()) -= *it;
521 }
522
523 return *this;
524 }
525
526
527
528 template<typename eT>
529 inline
530 const SpMat<eT>&
531 SpMat<eT>::operator*=(const SpMat<eT>& y)
532 {
533 arma_extra_debug_sigprint();
534
535 arma_debug_assert_mul_size(n_rows, n_cols, y.n_rows, y.n_cols, "matrix multiplication");
536
537 SpMat<eT> z;
538 z = (*this) * y;
539 steal_mem(z);
540
541 return *this;
542 }
543
544
545
546 // This is in-place element-wise matrix multiplication.
547 template<typename eT>
548 inline
549 const SpMat<eT>&
550 SpMat<eT>::operator%=(const SpMat<eT>& x)
551 {
552 arma_extra_debug_sigprint();
553
554 arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "element-wise multiplication");
555
556 // We can do this with two iterators rather simply.
557 iterator it = begin();
558 const_iterator x_it = x.begin();
559
560 while (it != end() && x_it != x.end())
561 {
562 // One of these will be further advanced than the other (or they will be at the same place).
563 if ((it.row() == x_it.row()) && (it.col() == x_it.col()))
564 {
565 // There is an element at this place in both matrices. Multiply.
566 (*it) *= (*x_it);
567
568 // Now move on to the next position.
569 it++;
570 x_it++;
571 }
572
573 else if ((it.col() < x_it.col()) || ((it.col() == x_it.col()) && (it.row() < x_it.row())))
574 {
575 // This case is when our matrix has an element which the other matrix does not.
576 // So we must delete this element.
577 (*it) = 0;
578
579 // Because we have deleted the element, we now have to manually set the position...
580 it.internal_pos--;
581
582 // Now we can increment our iterator.
583 it++;
584 }
585
586 else /* if our iterator is ahead of the other matrix */
587 {
588 // In this case we don't need to set anything to 0; our element is already 0.
589 // We can just increment the iterator of the other matrix.
590 x_it++;
591 }
592
593 }
594
595 // If we are not at the end of our matrix, then we must terminate the remaining elements.
596 while (it != end())
597 {
598 (*it) = 0;
599
600 // Hack to manually set the position right...
601 it.internal_pos--;
602 it++; // ...and then an increment.
603 }
604
605 return *this;
606 }
607
608
609
610 // Construct a complex matrix out of two non-complex matrices
611 template<typename eT>
612 template<typename T1, typename T2>
613 inline
614 SpMat<eT>::SpMat
615 (
616 const SpBase<typename SpMat<eT>::pod_type, T1>& A,
617 const SpBase<typename SpMat<eT>::pod_type, T2>& B
618 )
619 : n_rows(0)
620 , n_cols(0)
621 , n_elem(0)
622 , n_nonzero(0)
623 , vec_state(0)
624 , values(NULL) // extra element is set when mem_resize is called
625 , row_indices(NULL)
626 , col_ptrs(NULL)
627 {
628 arma_extra_debug_sigprint();
629
630 typedef typename T1::elem_type T;
631
632 // Make sure eT is complex and T is not (compile-time check).
633 arma_type_check(( is_complex<eT>::value == false ));
634 arma_type_check(( is_complex< T>::value == true ));
635
636 // Compile-time abort if types are not compatible.
637 arma_type_check(( is_same_type< std::complex<T>, eT >::value == false ));
638
639 const unwrap_spmat<T1> tmp1(A.get_ref());
640 const unwrap_spmat<T2> tmp2(B.get_ref());
641
642 const SpMat<T>& X = tmp1.M;
643 const SpMat<T>& Y = tmp2.M;
644
645 arma_debug_assert_same_size(X.n_rows, X.n_cols, Y.n_rows, Y.n_cols, "SpMat()");
646
647 const uword l_n_rows = X.n_rows;
648 const uword l_n_cols = X.n_cols;
649
650 // Set size of matrix correctly.
651 init(l_n_rows, l_n_cols);
652 mem_resize(n_unique(X, Y, op_n_unique_count()));
653
654 // Now on a second iteration, fill it.
655 typename SpMat<T>::const_iterator x_it = X.begin();
656 typename SpMat<T>::const_iterator x_end = X.end();
657
658 typename SpMat<T>::const_iterator y_it = Y.begin();
659 typename SpMat<T>::const_iterator y_end = Y.end();
660
661 uword cur_pos = 0;
662
663 while ((x_it != x_end) || (y_it != y_end))
664 {
665 if(x_it == y_it) // if we are at the same place
666 {
667 access::rw(values[cur_pos]) = std::complex<T>((T) *x_it, (T) *y_it);
668 access::rw(row_indices[cur_pos]) = x_it.row();
669 ++access::rw(col_ptrs[x_it.col() + 1]);
670
671 ++x_it;
672 ++y_it;
673 }
674 else
675 {
676 if((x_it.col() < y_it.col()) || ((x_it.col() == y_it.col()) && (x_it.row() < y_it.row()))) // if y is closer to the end
677 {
678 access::rw(values[cur_pos]) = std::complex<T>((T) *x_it, T(0));
679 access::rw(row_indices[cur_pos]) = x_it.row();
680 ++access::rw(col_ptrs[x_it.col() + 1]);
681
682 ++x_it;
683 }
684 else // x is closer to the end
685 {
686 access::rw(values[cur_pos]) = std::complex<T>(T(0), (T) *y_it);
687 access::rw(row_indices[cur_pos]) = y_it.row();
688 ++access::rw(col_ptrs[y_it.col() + 1]);
689
690 ++y_it;
691 }
692 }
693
694 ++cur_pos;
695 }
696
697 // Now fix the column pointers; they are supposed to be a sum.
698 for (uword c = 1; c <= n_cols; ++c)
699 {
700 access::rw(col_ptrs[c]) += col_ptrs[c - 1];
701 }
702
703 }
704
705
706
707 template<typename eT>
708 inline
709 const SpMat<eT>&
710 SpMat<eT>::operator/=(const SpMat<eT>& x)
711 {
712 arma_extra_debug_sigprint();
713
714 arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "element-wise division");
715
716 // If you use this method, you are probably stupid or misguided, but for compatibility with Mat, we have implemented it anyway.
717 // We have to loop over every element, which is not good. In fact, it makes me physically sad to write this.
718 for(uword c = 0; c < n_cols; ++c)
719 {
720 for(uword r = 0; r < n_rows; ++r)
721 {
722 at(r, c) /= x.at(r, c);
723 }
724 }
725
726 return *this;
727 }
728
729
730
731 template<typename eT>
732 template<typename T1>
733 inline
734 SpMat<eT>::SpMat(const Base<eT, T1>& x)
735 : n_rows(0)
736 , n_cols(0)
737 , n_elem(0)
738 , n_nonzero(0)
739 , vec_state(0)
740 , values(NULL) // extra element is set when mem_resize is called in operator=()
741 , row_indices(NULL)
742 , col_ptrs(NULL)
743 {
744 arma_extra_debug_sigprint_this(this);
745
746 (*this).operator=(x);
747 }
748
749
750
751 template<typename eT>
752 template<typename T1>
753 inline
754 const SpMat<eT>&
755 SpMat<eT>::operator=(const Base<eT, T1>& x)
756 {
757 arma_extra_debug_sigprint();
758
759 const Proxy<T1> p(x.get_ref());
760
761 const uword x_n_rows = p.get_n_rows();
762 const uword x_n_cols = p.get_n_cols();
763 const uword x_n_elem = p.get_n_elem();
764
765 init(x_n_rows, x_n_cols);
766
767 // Count number of nonzero elements in base object.
768 uword n = 0;
769 if(Proxy<T1>::prefer_at_accessor == true)
770 {
771 for(uword j = 0; j < x_n_cols; ++j)
772 for(uword i = 0; i < x_n_rows; ++i)
773 {
774 if(p.at(i, j) != eT(0)) { ++n; }
775 }
776 }
777 else
778 {
779 for(uword i = 0; i < x_n_elem; ++i)
780 {
781 if(p[i] != eT(0)) { ++n; }
782 }
783 }
784
785 mem_resize(n);
786
787 // Now the memory is resized correctly; add nonzero elements.
788 n = 0;
789 for(uword j = 0; j < x_n_cols; ++j)
790 for(uword i = 0; i < x_n_rows; ++i)
791 {
792 const eT val = p.at(i, j);
793
794 if(val != eT(0))
795 {
796 access::rw(values[n]) = val;
797 access::rw(row_indices[n]) = i;
798 access::rw(col_ptrs[j + 1])++;
799 ++n;
800 }
801 }
802
803 // Sum column counts to be column pointers.
804 for(uword c = 1; c <= n_cols; ++c)
805 {
806 access::rw(col_ptrs[c]) += col_ptrs[c - 1];
807 }
808
809 return *this;
810 }
811
812
813
814 template<typename eT>
815 template<typename T1>
816 inline
817 const SpMat<eT>&
818 SpMat<eT>::operator*=(const Base<eT, T1>& y)
819 {
820 arma_extra_debug_sigprint();
821
822 const Proxy<T1> p(y.get_ref());
823
824 arma_debug_assert_mul_size(n_rows, n_cols, p.get_n_rows(), p.get_n_cols(), "matrix multiplication");
825
826 // We assume the matrix structure is such that we will end up with a sparse
827 // matrix. Assuming that every entry in the dense matrix is nonzero (which is
828 // a fairly valid assumption), each row with any nonzero elements in it (in this
829 // matrix) implies an entire nonzero column. Therefore, we iterate over all
830 // the row_indices and count the number of rows with any elements in them
831 // (using the quasi-linked-list idea from SYMBMM -- see operator_times.hpp).
832 podarray<uword> index(n_rows);
833 index.fill(n_rows); // Fill with invalid links.
834
835 uword last_index = n_rows + 1;
836 for(uword i = 0; i < n_nonzero; ++i)
837 {
838 if(index[row_indices[i]] == n_rows)
839 {
840 index[row_indices[i]] = last_index;
841 last_index = row_indices[i];
842 }
843 }
844
845 // Now count the number of rows which have nonzero elements.
846 uword nonzero_rows = 0;
847 while(last_index != n_rows + 1)
848 {
849 ++nonzero_rows;
850 last_index = index[last_index];
851 }
852
853 SpMat<eT> z(n_rows, p.get_n_cols());
854
855 z.mem_resize(nonzero_rows * p.get_n_cols()); // upper bound on size
856
857 // Now we have to fill all the elements using a modification of the NUMBMM algorithm.
858 uword cur_pos = 0;
859
860 podarray<eT> partial_sums(n_rows);
861 partial_sums.zeros();
862
863 for(uword lcol = 0; lcol < n_cols; ++lcol)
864 {
865 const_iterator it = begin();
866
867 while(it != end())
868 {
869 const eT value = (*it);
870
871 partial_sums[it.row()] += (value * p.at(it.col(), lcol));
872
873 ++it;
874 }
875
876 // Now add all partial sums to the matrix.
877 for(uword i = 0; i < n_rows; ++i)
878 {
879 if(partial_sums[i] != eT(0))
880 {
881 access::rw(z.values[cur_pos]) = partial_sums[i];
882 access::rw(z.row_indices[cur_pos]) = i;
883 ++access::rw(z.col_ptrs[lcol + 1]);
884 //printf("colptr %d now %d\n", lcol + 1, z.col_ptrs[lcol + 1]);
885 ++cur_pos;
886 partial_sums[i] = 0; // Would it be faster to do this in batch later?
887 }
888 }
889 }
890
891 // Now fix the column pointers.
892 for(uword c = 1; c <= z.n_cols; ++c)
893 {
894 access::rw(z.col_ptrs[c]) += z.col_ptrs[c - 1];
895 }
896
897 // Resize to final correct size.
898 z.mem_resize(z.col_ptrs[z.n_cols]);
899
900 // Now take the memory of the temporary matrix.
901 steal_mem(z);
902
903 return *this;
904 }
905
906
907
908 /**
909 * Don't use this function. It's not mathematically well-defined and wastes
910 * cycles to trash all your data. This is dumb.
911 */
912 template<typename eT>
913 template<typename T1>
914 inline
915 const SpMat<eT>&
916 SpMat<eT>::operator/=(const Base<eT, T1>& x)
917 {
918 arma_extra_debug_sigprint();
919
920 SpMat<eT> tmp = (*this) / x.get_ref();
921
922 steal_mem(tmp);
923
924 return *this;
925 }
926
927
928
929 template<typename eT>
930 template<typename T1>
931 inline
932 const SpMat<eT>&
933 SpMat<eT>::operator%=(const Base<eT, T1>& x)
934 {
935 arma_extra_debug_sigprint();
936
937 const Proxy<T1> p(x.get_ref());
938
939 arma_debug_assert_same_size(n_rows, n_cols, p.get_n_rows(), p.get_n_cols(), "element-wise multiplication");
940
941 // Count the number of elements we will need.
942 SpMat<eT> tmp(n_rows, n_cols);
943 const_iterator it = begin();
944 uword new_n_nonzero = 0;
945
946 while(it != end())
947 {
948 // prefer_at_accessor == false can't save us any work here
949 if(((*it) * p.at(it.row(), it.col())) != eT(0))
950 {
951 ++new_n_nonzero;
952 }
953 ++it;
954 }
955
956 // Resize.
957 tmp.mem_resize(new_n_nonzero);
958
959 const_iterator c_it = begin();
960 uword cur_pos = 0;
961 while(c_it != end())
962 {
963 // prefer_at_accessor == false can't save us any work here
964 const eT val = (*c_it) * p.at(c_it.row(), c_it.col());
965 if(val != eT(0))
966 {
967 access::rw(tmp.values[cur_pos]) = val;
968 access::rw(tmp.row_indices[cur_pos]) = c_it.row();
969 ++access::rw(tmp.col_ptrs[c_it.col() + 1]);
970 ++cur_pos;
971 }
972
973 ++c_it;
974 }
975
976 // Fix column pointers.
977 for(uword c = 1; c <= n_cols; ++c)
978 {
979 access::rw(tmp.col_ptrs[c]) += tmp.col_ptrs[c - 1];
980 }
981
982 steal_mem(tmp);
983
984 return *this;
985 }
986
987
988
989 /**
990 * Functions on subviews.
991 */
992 template<typename eT>
993 inline
994 SpMat<eT>::SpMat(const SpSubview<eT>& X)
995 : n_rows(0)
996 , n_cols(0)
997 , n_elem(0)
998 , n_nonzero(0)
999 , vec_state(0)
1000 , values(NULL) // extra element added when mem_resize is called
1001 , row_indices(NULL)
1002 , col_ptrs(NULL)
1003 {
1004 arma_extra_debug_sigprint_this(this);
1005
1006 (*this).operator=(X);
1007 }
1008
1009
1010
1011 template<typename eT>
1012 inline
1013 const SpMat<eT>&
1014 SpMat<eT>::operator=(const SpSubview<eT>& X)
1015 {
1016 arma_extra_debug_sigprint();
1017
1018 const uword in_n_cols = X.n_cols;
1019 const uword in_n_rows = X.n_rows;
1020
1021 const bool alias = (this == &(X.m));
1022
1023 if(alias == false)
1024 {
1025 init(in_n_rows, in_n_cols);
1026
1027 const uword x_n_nonzero = X.n_nonzero;
1028
1029 mem_resize(x_n_nonzero);
1030
1031 typename SpSubview<eT>::const_iterator it = X.begin();
1032
1033 while(it != X.end())
1034 {
1035 access::rw(row_indices[it.pos()]) = it.row();
1036 access::rw(values[it.pos()]) = (*it);
1037 ++access::rw(col_ptrs[it.col() + 1]);
1038 ++it;
1039 }
1040
1041 // Now sum column pointers.
1042 for(uword c = 1; c <= n_cols; ++c)
1043 {
1044 access::rw(col_ptrs[c]) += col_ptrs[c - 1];
1045 }
1046 }
1047 else
1048 {
1049 // Create it in a temporary.
1050 SpMat<eT> tmp(X);
1051
1052 steal_mem(tmp);
1053 }
1054
1055 return *this;
1056 }
1057
1058
1059
1060 template<typename eT>
1061 inline
1062 const SpMat<eT>&
1063 SpMat<eT>::operator+=(const SpSubview<eT>& X)
1064 {
1065 arma_extra_debug_sigprint();
1066
1067 arma_debug_assert_same_size(n_rows, n_cols, X.n_rows, X.n_cols, "addition");
1068
1069 typename SpSubview<eT>::const_iterator it = X.begin();
1070
1071 while(it != X.end())
1072 {
1073 at(it.row(), it.col()) += (*it);
1074 ++it;
1075 }
1076
1077 return *this;
1078 }
1079
1080
1081
1082 template<typename eT>
1083 inline
1084 const SpMat<eT>&
1085 SpMat<eT>::operator-=(const SpSubview<eT>& X)
1086 {
1087 arma_extra_debug_sigprint();
1088
1089 arma_debug_assert_same_size(n_rows, n_cols, X.n_rows, X.n_cols, "subtraction");
1090
1091 typename SpSubview<eT>::const_iterator it = X.begin();
1092
1093 while(it != X.end())
1094 {
1095 at(it.row(), it.col()) -= (*it);
1096 ++it;
1097 }
1098
1099 return *this;
1100 }
1101
1102
1103
1104 template<typename eT>
1105 inline
1106 const SpMat<eT>&
1107 SpMat<eT>::operator*=(const SpSubview<eT>& y)
1108 {
1109 arma_extra_debug_sigprint();
1110
1111 arma_debug_assert_mul_size(n_rows, n_cols, y.n_rows, y.n_cols, "matrix multiplication");
1112
1113 // Cannot be done in-place (easily).
1114 SpMat<eT> z = (*this) * y;
1115 steal_mem(z);
1116
1117 return *this;
1118 }
1119
1120
1121
1122 template<typename eT>
1123 inline
1124 const SpMat<eT>&
1125 SpMat<eT>::operator%=(const SpSubview<eT>& x)
1126 {
1127 arma_extra_debug_sigprint();
1128
1129 arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "element-wise multiplication");
1130
1131 iterator it = begin();
1132 typename SpSubview<eT>::const_iterator xit = x.begin();
1133
1134 while((it != end()) || (xit != x.end()))
1135 {
1136 if((xit.row() == it.row()) && (xit.col() == it.col()))
1137 {
1138 (*it) *= (*xit);
1139 ++it;
1140 ++xit;
1141 }
1142 else
1143 {
1144 if((xit.col() > it.col()) || ((xit.col() == it.col()) && (xit.row() > it.row())))
1145 {
1146 // xit is "ahead"
1147 (*it) = eT(0); // erase element; x has a zero here
1148 it.internal_pos--; // update iterator so it still works
1149 ++it;
1150 }
1151 else
1152 {
1153 // it is "ahead"
1154 ++xit;
1155 }
1156 }
1157 }
1158
1159 return *this;
1160 }
1161
1162
1163 template<typename eT>
1164 inline
1165 const SpMat<eT>&
1166 SpMat<eT>::operator/=(const SpSubview<eT>& x)
1167 {
1168 arma_extra_debug_sigprint();
1169
1170 arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "element-wise division");
1171
1172 // There is no pretty way to do this.
1173 for(uword elem = 0; elem < n_elem; elem++)
1174 {
1175 at(elem) /= x(elem);
1176 }
1177
1178 return *this;
1179 }
1180
1181
1182
1183 /**
1184 * Operators on regular subviews.
1185 */
1186 template<typename eT>
1187 inline
1188 SpMat<eT>::SpMat(const subview<eT>& x)
1189 : n_rows(0)
1190 , n_cols(0)
1191 , n_elem(0)
1192 , n_nonzero(0)
1193 , vec_state(0)
1194 , values(NULL) // extra value set in operator=()
1195 , row_indices(NULL)
1196 , col_ptrs(NULL)
1197 {
1198 arma_extra_debug_sigprint_this(this);
1199
1200 (*this).operator=(x);
1201 }
1202
1203
1204
1205 template<typename eT>
1206 inline
1207 const SpMat<eT>&
1208 SpMat<eT>::operator=(const subview<eT>& x)
1209 {
1210 arma_extra_debug_sigprint();
1211
1212 const uword x_n_rows = x.n_rows;
1213 const uword x_n_cols = x.n_cols;
1214
1215 // Set the size correctly.
1216 init(x_n_rows, x_n_cols);
1217
1218 // Count number of nonzero elements.
1219 uword n = 0;
1220 for(uword c = 0; c < x_n_cols; ++c)
1221 {
1222 for(uword r = 0; r < x_n_rows; ++r)
1223 {
1224 if(x.at(r, c) != eT(0))
1225 {
1226 ++n;
1227 }
1228 }
1229 }
1230
1231 // Resize memory appropriately.
1232 mem_resize(n);
1233
1234 n = 0;
1235 for(uword c = 0; c < x_n_cols; ++c)
1236 {
1237 for(uword r = 0; r < x_n_rows; ++r)
1238 {
1239 const eT val = x.at(r, c);
1240
1241 if(val != eT(0))
1242 {
1243 access::rw(values[n]) = val;
1244 access::rw(row_indices[n]) = r;
1245 ++access::rw(col_ptrs[c + 1]);
1246 ++n;
1247 }
1248 }
1249 }
1250
1251 // Fix column counts into column pointers.
1252 for(uword c = 1; c <= n_cols; ++c)
1253 {
1254 access::rw(col_ptrs[c]) += col_ptrs[c - 1];
1255 }
1256
1257 return *this;
1258 }
1259
1260
1261
1262 template<typename eT>
1263 inline
1264 const SpMat<eT>&
1265 SpMat<eT>::operator+=(const subview<eT>& x)
1266 {
1267 arma_extra_debug_sigprint();
1268
1269 arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "addition");
1270
1271 // Loop over every element. This could probably be written in a more
1272 // efficient way, by calculating the number of nonzero elements the output
1273 // matrix will have, allocating the memory correctly, and then filling the
1274 // matrix correctly. However... for now, this works okay.
1275 for(uword lcol = 0; lcol < n_cols; ++lcol)
1276 for(uword lrow = 0; lrow < n_rows; ++lrow)
1277 {
1278 at(lrow, lcol) += x.at(lrow, lcol);
1279 }
1280
1281 return *this;
1282 }
1283
1284
1285
1286 template<typename eT>
1287 inline
1288 const SpMat<eT>&
1289 SpMat<eT>::operator-=(const subview<eT>& x)
1290 {
1291 arma_extra_debug_sigprint();
1292
1293 arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "subtraction");
1294
1295 // Loop over every element.
1296 for(uword lcol = 0; lcol < n_cols; ++lcol)
1297 for(uword lrow = 0; lrow < n_rows; ++lrow)
1298 {
1299 at(lrow, lcol) -= x.at(lrow, lcol);
1300 }
1301
1302 return *this;
1303 }
1304
1305
1306
1307 template<typename eT>
1308 inline
1309 const SpMat<eT>&
1310 SpMat<eT>::operator*=(const subview<eT>& y)
1311 {
1312 arma_extra_debug_sigprint();
1313
1314 arma_debug_assert_mul_size(n_rows, n_cols, y.n_rows, y.n_cols, "matrix multiplication");
1315
1316 SpMat<eT> z(n_rows, y.n_cols);
1317
1318 // Performed in the same fashion as operator*=(SpMat).
1319 for (const_row_iterator x_row_it = begin_row(); x_row_it.pos() < n_nonzero; ++x_row_it)
1320 {
1321 for (uword lcol = 0; lcol < y.n_cols; ++lcol)
1322 {
1323 // At this moment in the loop, we are calculating anything that is contributed to by *x_row_it and *y_col_it.
1324 // Given that our position is x_ab and y_bc, there will only be a contribution if x.col == y.row, and that
1325 // contribution will be in location z_ac.
1326 z.at(x_row_it.row, lcol) += (*x_row_it) * y.at(x_row_it.col, lcol);
1327 }
1328 }
1329
1330 steal_mem(z);
1331
1332 return *this;
1333 }
1334
1335
1336
1337 template<typename eT>
1338 inline
1339 const SpMat<eT>&
1340 SpMat<eT>::operator%=(const subview<eT>& x)
1341 {
1342 arma_extra_debug_sigprint();
1343
1344 arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "element-wise multiplication");
1345
1346 // Loop over every element.
1347 for(uword lcol = 0; lcol < n_cols; ++lcol)
1348 for(uword lrow = 0; lrow < n_rows; ++lrow)
1349 {
1350 at(lrow, lcol) *= x.at(lrow, lcol);
1351 }
1352
1353 return *this;
1354 }
1355
1356
1357
1358 template<typename eT>
1359 inline
1360 const SpMat<eT>&
1361 SpMat<eT>::operator/=(const subview<eT>& x)
1362 {
1363 arma_extra_debug_sigprint();
1364
1365 arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "element-wise division");
1366
1367 // Loop over every element.
1368 for(uword lcol = 0; lcol < n_cols; ++lcol)
1369 for(uword lrow = 0; lrow < n_rows; ++lrow)
1370 {
1371 at(lrow, lcol) /= x.at(lrow, lcol);
1372 }
1373
1374 return *this;
1375 }
1376
1377
1378
1379 template<typename eT>
1380 template<typename T1, typename spop_type>
1381 inline
1382 SpMat<eT>::SpMat(const SpOp<T1, spop_type>& X)
1383 : n_rows(0)
1384 , n_cols(0)
1385 , n_elem(0)
1386 , n_nonzero(0)
1387 , vec_state(0)
1388 , values(NULL) // set in application of sparse operation
1389 , row_indices(NULL)
1390 , col_ptrs(NULL)
1391 {
1392 arma_extra_debug_sigprint_this(this);
1393
1394 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false ));
1395
1396 spop_type::apply(*this, X);
1397 }
1398
1399
1400
1401 template<typename eT>
1402 template<typename T1, typename spop_type>
1403 inline
1404 const SpMat<eT>&
1405 SpMat<eT>::operator=(const SpOp<T1, spop_type>& X)
1406 {
1407 arma_extra_debug_sigprint();
1408
1409 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false ));
1410
1411 spop_type::apply(*this, X);
1412
1413 return *this;
1414 }
1415
1416
1417
1418 template<typename eT>
1419 template<typename T1, typename spop_type>
1420 inline
1421 const SpMat<eT>&
1422 SpMat<eT>::operator+=(const SpOp<T1, spop_type>& X)
1423 {
1424 arma_extra_debug_sigprint();
1425
1426 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false ));
1427
1428 const SpMat<eT> m(X);
1429
1430 return (*this).operator+=(m);
1431 }
1432
1433
1434
1435 template<typename eT>
1436 template<typename T1, typename spop_type>
1437 inline
1438 const SpMat<eT>&
1439 SpMat<eT>::operator-=(const SpOp<T1, spop_type>& X)
1440 {
1441 arma_extra_debug_sigprint();
1442
1443 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false ));
1444
1445 const SpMat<eT> m(X);
1446
1447 return (*this).operator-=(m);
1448 }
1449
1450
1451
1452 template<typename eT>
1453 template<typename T1, typename spop_type>
1454 inline
1455 const SpMat<eT>&
1456 SpMat<eT>::operator*=(const SpOp<T1, spop_type>& X)
1457 {
1458 arma_extra_debug_sigprint();
1459
1460 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false ));
1461
1462 const SpMat<eT> m(X);
1463
1464 return (*this).operator*=(m);
1465 }
1466
1467
1468
1469 template<typename eT>
1470 template<typename T1, typename spop_type>
1471 inline
1472 const SpMat<eT>&
1473 SpMat<eT>::operator%=(const SpOp<T1, spop_type>& X)
1474 {
1475 arma_extra_debug_sigprint();
1476
1477 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false ));
1478
1479 const SpMat<eT> m(X);
1480
1481 return (*this).operator%=(m);
1482 }
1483
1484
1485
1486 template<typename eT>
1487 template<typename T1, typename spop_type>
1488 inline
1489 const SpMat<eT>&
1490 SpMat<eT>::operator/=(const SpOp<T1, spop_type>& X)
1491 {
1492 arma_extra_debug_sigprint();
1493
1494 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false ));
1495
1496 const SpMat<eT> m(X);
1497
1498 return (*this).operator/=(m);
1499 }
1500
1501
1502
1503 template<typename eT>
1504 template<typename T1, typename T2, typename spglue_type>
1505 inline
1506 SpMat<eT>::SpMat(const SpGlue<T1, T2, spglue_type>& X)
1507 : n_rows(0)
1508 , n_cols(0)
1509 , n_elem(0)
1510 , n_nonzero(0)
1511 , vec_state(0)
1512 , values(NULL) // extra element set in application of sparse glue
1513 , row_indices(NULL)
1514 , col_ptrs(NULL)
1515 {
1516 arma_extra_debug_sigprint_this(this);
1517
1518 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false ));
1519
1520 spglue_type::apply(*this, X);
1521 }
1522
1523
1524
1525 template<typename eT>
1526 template<typename T1, typename spop_type>
1527 inline
1528 SpMat<eT>::SpMat(const mtSpOp<eT, T1, spop_type>& X)
1529 : n_rows(0)
1530 , n_cols(0)
1531 , n_elem(0)
1532 , n_nonzero(0)
1533 , vec_state(0)
1534 , values(NULL) // extra element set in application of sparse glue
1535 , row_indices(NULL)
1536 , col_ptrs(NULL)
1537 {
1538 arma_extra_debug_sigprint_this(this);
1539
1540 spop_type::apply(*this, X);
1541 }
1542
1543
1544
1545 template<typename eT>
1546 template<typename T1, typename spop_type>
1547 inline
1548 const SpMat<eT>&
1549 SpMat<eT>::operator=(const mtSpOp<eT, T1, spop_type>& X)
1550 {
1551 arma_extra_debug_sigprint();
1552
1553 spop_type::apply(*this, X);
1554
1555 return *this;
1556 }
1557
1558
1559
1560 template<typename eT>
1561 template<typename T1, typename spop_type>
1562 inline
1563 const SpMat<eT>&
1564 SpMat<eT>::operator+=(const mtSpOp<eT, T1, spop_type>& X)
1565 {
1566 arma_extra_debug_sigprint();
1567
1568 const SpMat<eT> m(X);
1569
1570 return (*this).operator+=(m);
1571 }
1572
1573
1574
1575 template<typename eT>
1576 template<typename T1, typename spop_type>
1577 inline
1578 const SpMat<eT>&
1579 SpMat<eT>::operator-=(const mtSpOp<eT, T1, spop_type>& X)
1580 {
1581 arma_extra_debug_sigprint();
1582
1583 const SpMat<eT> m(X);
1584
1585 return (*this).operator-=(m);
1586 }
1587
1588
1589
1590 template<typename eT>
1591 template<typename T1, typename spop_type>
1592 inline
1593 const SpMat<eT>&
1594 SpMat<eT>::operator*=(const mtSpOp<eT, T1, spop_type>& X)
1595 {
1596 arma_extra_debug_sigprint();
1597
1598 const SpMat<eT> m(X);
1599
1600 return (*this).operator*=(m);
1601 }
1602
1603
1604
1605 template<typename eT>
1606 template<typename T1, typename spop_type>
1607 inline
1608 const SpMat<eT>&
1609 SpMat<eT>::operator%=(const mtSpOp<eT, T1, spop_type>& X)
1610 {
1611 arma_extra_debug_sigprint();
1612
1613 const SpMat<eT> m(X);
1614
1615 return (*this).operator%=(m);
1616 }
1617
1618
1619
1620 template<typename eT>
1621 template<typename T1, typename spop_type>
1622 inline
1623 const SpMat<eT>&
1624 SpMat<eT>::operator/=(const mtSpOp<eT, T1, spop_type>& X)
1625 {
1626 arma_extra_debug_sigprint();
1627
1628 const SpMat<eT> m(X);
1629
1630 return (*this).operator/=(m);
1631 }
1632
1633
1634
1635 template<typename eT>
1636 template<typename T1, typename T2, typename spglue_type>
1637 inline
1638 const SpMat<eT>&
1639 SpMat<eT>::operator=(const SpGlue<T1, T2, spglue_type>& X)
1640 {
1641 arma_extra_debug_sigprint();
1642
1643 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false ));
1644
1645 spglue_type::apply(*this, X);
1646
1647 return *this;
1648 }
1649
1650
1651
1652 template<typename eT>
1653 template<typename T1, typename T2, typename spglue_type>
1654 inline
1655 const SpMat<eT>&
1656 SpMat<eT>::operator+=(const SpGlue<T1, T2, spglue_type>& X)
1657 {
1658 arma_extra_debug_sigprint();
1659
1660 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false ));
1661
1662 const SpMat<eT> m(X);
1663
1664 return (*this).operator+=(m);
1665 }
1666
1667
1668
1669 template<typename eT>
1670 template<typename T1, typename T2, typename spglue_type>
1671 inline
1672 const SpMat<eT>&
1673 SpMat<eT>::operator-=(const SpGlue<T1, T2, spglue_type>& X)
1674 {
1675 arma_extra_debug_sigprint();
1676
1677 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false ));
1678
1679 const SpMat<eT> m(X);
1680
1681 return (*this).operator-=(m);
1682 }
1683
1684
1685
1686 template<typename eT>
1687 template<typename T1, typename T2, typename spglue_type>
1688 inline
1689 const SpMat<eT>&
1690 SpMat<eT>::operator*=(const SpGlue<T1, T2, spglue_type>& X)
1691 {
1692 arma_extra_debug_sigprint();
1693
1694 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false ));
1695
1696 const SpMat<eT> m(X);
1697
1698 return (*this).operator*=(m);
1699 }
1700
1701
1702
1703 template<typename eT>
1704 template<typename T1, typename T2, typename spglue_type>
1705 inline
1706 const SpMat<eT>&
1707 SpMat<eT>::operator%=(const SpGlue<T1, T2, spglue_type>& X)
1708 {
1709 arma_extra_debug_sigprint();
1710
1711 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false ));
1712
1713 const SpMat<eT> m(X);
1714
1715 return (*this).operator%=(m);
1716 }
1717
1718
1719
1720 template<typename eT>
1721 template<typename T1, typename T2, typename spglue_type>
1722 inline
1723 const SpMat<eT>&
1724 SpMat<eT>::operator/=(const SpGlue<T1, T2, spglue_type>& X)
1725 {
1726 arma_extra_debug_sigprint();
1727
1728 arma_type_check(( is_same_type< eT, typename T1::elem_type >::value == false ));
1729
1730 const SpMat<eT> m(X);
1731
1732 return (*this).operator/=(m);
1733 }
1734
1735
1736
1737 template<typename eT>
1738 arma_inline
1739 SpSubview<eT>
1740 SpMat<eT>::row(const uword row_num)
1741 {
1742 arma_extra_debug_sigprint();
1743
1744 arma_debug_check(row_num >= n_rows, "SpMat::row(): out of bounds");
1745
1746 return SpSubview<eT>(*this, row_num, 0, 1, n_cols);
1747 }
1748
1749
1750
1751 template<typename eT>
1752 arma_inline
1753 const SpSubview<eT>
1754 SpMat<eT>::row(const uword row_num) const
1755 {
1756 arma_extra_debug_sigprint();
1757
1758 arma_debug_check(row_num >= n_rows, "SpMat::row(): out of bounds");
1759
1760 return SpSubview<eT>(*this, row_num, 0, 1, n_cols);
1761 }
1762
1763
1764
1765 template<typename eT>
1766 inline
1767 SpSubview<eT>
1768 SpMat<eT>::operator()(const uword row_num, const span& col_span)
1769 {
1770 arma_extra_debug_sigprint();
1771
1772 const bool col_all = col_span.whole;
1773
1774 const uword local_n_cols = n_cols;
1775
1776 const uword in_col1 = col_all ? 0 : col_span.a;
1777 const uword in_col2 = col_span.b;
1778 const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1;
1779
1780 arma_debug_check
1781 (
1782 (row_num >= n_rows)
1783 ||
1784 ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) )
1785 ,
1786 "SpMat::operator(): indices out of bounds or incorrectly used"
1787 );
1788
1789 return SpSubview<eT>(*this, row_num, in_col1, 1, submat_n_cols);
1790 }
1791
1792
1793
1794 template<typename eT>
1795 inline
1796 const SpSubview<eT>
1797 SpMat<eT>::operator()(const uword row_num, const span& col_span) const
1798 {
1799 arma_extra_debug_sigprint();
1800
1801 const bool col_all = col_span.whole;
1802
1803 const uword local_n_cols = n_cols;
1804
1805 const uword in_col1 = col_all ? 0 : col_span.a;
1806 const uword in_col2 = col_span.b;
1807 const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1;
1808
1809 arma_debug_check
1810 (
1811 (row_num >= n_rows)
1812 ||
1813 ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) )
1814 ,
1815 "SpMat::operator(): indices out of bounds or incorrectly used"
1816 );
1817
1818 return SpSubview<eT>(*this, row_num, in_col1, 1, submat_n_cols);
1819 }
1820
1821
1822
1823 template<typename eT>
1824 arma_inline
1825 SpSubview<eT>
1826 SpMat<eT>::col(const uword col_num)
1827 {
1828 arma_extra_debug_sigprint();
1829
1830 arma_debug_check(col_num >= n_cols, "SpMat::col(): out of bounds");
1831
1832 return SpSubview<eT>(*this, 0, col_num, n_rows, 1);
1833 }
1834
1835
1836
1837 template<typename eT>
1838 arma_inline
1839 const SpSubview<eT>
1840 SpMat<eT>::col(const uword col_num) const
1841 {
1842 arma_extra_debug_sigprint();
1843
1844 arma_debug_check(col_num >= n_cols, "SpMat::col(): out of bounds");
1845
1846 return SpSubview<eT>(*this, 0, col_num, n_rows, 1);
1847 }
1848
1849
1850
1851 template<typename eT>
1852 inline
1853 SpSubview<eT>
1854 SpMat<eT>::operator()(const span& row_span, const uword col_num)
1855 {
1856 arma_extra_debug_sigprint();
1857
1858 const bool row_all = row_span.whole;
1859
1860 const uword local_n_rows = n_rows;
1861
1862 const uword in_row1 = row_all ? 0 : row_span.a;
1863 const uword in_row2 = row_span.b;
1864 const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1;
1865
1866 arma_debug_check
1867 (
1868 (col_num >= n_cols)
1869 ||
1870 ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) )
1871 ,
1872 "SpMat::operator(): indices out of bounds or incorrectly used"
1873 );
1874
1875 return SpSubview<eT>(*this, in_row1, col_num, submat_n_rows, 1);
1876 }
1877
1878
1879
1880 template<typename eT>
1881 inline
1882 const SpSubview<eT>
1883 SpMat<eT>::operator()(const span& row_span, const uword col_num) const
1884 {
1885 arma_extra_debug_sigprint();
1886
1887 const bool row_all = row_span.whole;
1888
1889 const uword local_n_rows = n_rows;
1890
1891 const uword in_row1 = row_all ? 0 : row_span.a;
1892 const uword in_row2 = row_span.b;
1893 const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1;
1894
1895 arma_debug_check
1896 (
1897 (col_num >= n_cols)
1898 ||
1899 ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) )
1900 ,
1901 "SpMat::operator(): indices out of bounds or incorrectly used"
1902 );
1903
1904 return SpSubview<eT>(*this, in_row1, col_num, submat_n_rows, 1);
1905 }
1906
1907
1908
1909 /**
1910 * Swap in_row1 with in_row2.
1911 */
1912 template<typename eT>
1913 inline
1914 void
1915 SpMat<eT>::swap_rows(const uword in_row1, const uword in_row2)
1916 {
1917 arma_extra_debug_sigprint();
1918
1919 arma_debug_check
1920 (
1921 (in_row1 >= n_rows) || (in_row2 >= n_rows),
1922 "SpMat::swap_rows(): out of bounds"
1923 );
1924
1925 // Sanity check.
1926 if (in_row1 == in_row2)
1927 {
1928 return;
1929 }
1930
1931 // The easier way to do this, instead of collecting all the elements in one row and then swapping with the other, will be
1932 // to iterate over each column of the matrix (since we store in column-major format) and then swap the two elements in the two rows at that time.
1933 // We will try to avoid using the at() call since it is expensive, instead preferring to use an iterator to track our position.
1934 uword col1 = (in_row1 < in_row2) ? in_row1 : in_row2;
1935 uword col2 = (in_row1 < in_row2) ? in_row2 : in_row1;
1936
1937 for (uword lcol = 0; lcol < n_cols; lcol++)
1938 {
1939 // If there is nothing in this column we can ignore it.
1940 if (col_ptrs[lcol] == col_ptrs[lcol + 1])
1941 {
1942 continue;
1943 }
1944
1945 // These will represent the positions of the items themselves.
1946 uword loc1 = n_nonzero + 1;
1947 uword loc2 = n_nonzero + 1;
1948
1949 for (uword search_pos = col_ptrs[lcol]; search_pos < col_ptrs[lcol + 1]; search_pos++)
1950 {
1951 if (row_indices[search_pos] == col1)
1952 {
1953 loc1 = search_pos;
1954 }
1955
1956 if (row_indices[search_pos] == col2)
1957 {
1958 loc2 = search_pos;
1959 break; // No need to look any further.
1960 }
1961 }
1962
1963 // There are four cases: we found both elements; we found one element (loc1); we found one element (loc2); we found zero elements.
1964 // If we found zero elements no work needs to be done and we can continue to the next column.
1965 if ((loc1 != (n_nonzero + 1)) && (loc2 != (n_nonzero + 1)))
1966 {
1967 // This is an easy case: just swap the values. No index modifying necessary.
1968 eT tmp = values[loc1];
1969 access::rw(values[loc1]) = values[loc2];
1970 access::rw(values[loc2]) = tmp;
1971 }
1972 else if (loc1 != (n_nonzero + 1)) // We only found loc1 and not loc2.
1973 {
1974 // We need to find the correct place to move our value to. It will be forward (not backwards) because in_row2 > in_row1.
1975 // Each iteration of the loop swaps the current value (loc1) with (loc1 + 1); in this manner we move our value down to where it should be.
1976 while (((loc1 + 1) < col_ptrs[lcol + 1]) && (row_indices[loc1 + 1] < in_row2))
1977 {
1978 // Swap both the values and the indices. The column should not change.
1979 eT tmp = values[loc1];
1980 access::rw(values[loc1]) = values[loc1 + 1];
1981 access::rw(values[loc1 + 1]) = tmp;
1982
1983 uword tmp_index = row_indices[loc1];
1984 access::rw(row_indices[loc1]) = row_indices[loc1 + 1];
1985 access::rw(row_indices[loc1 + 1]) = tmp_index;
1986
1987 loc1++; // And increment the counter.
1988 }
1989
1990 // Now set the row index correctly.
1991 access::rw(row_indices[loc1]) = in_row2;
1992
1993 }
1994 else if (loc2 != (n_nonzero + 1))
1995 {
1996 // We need to find the correct place to move our value to. It will be backwards (not forwards) because in_row1 < in_row2.
1997 // Each iteration of the loop swaps the current value (loc2) with (loc2 - 1); in this manner we move our value up to where it should be.
1998 while (((loc2 - 1) >= col_ptrs[lcol]) && (row_indices[loc2 - 1] > in_row1))
1999 {
2000 // Swap both the values and the indices. The column should not change.
2001 eT tmp = values[loc2];
2002 access::rw(values[loc2]) = values[loc2 - 1];
2003 access::rw(values[loc2 - 1]) = tmp;
2004
2005 uword tmp_index = row_indices[loc2];
2006 access::rw(row_indices[loc2]) = row_indices[loc2 - 1];
2007 access::rw(row_indices[loc2 - 1]) = tmp_index;
2008
2009 loc2--; // And decrement the counter.
2010 }
2011
2012 // Now set the row index correctly.
2013 access::rw(row_indices[loc2]) = in_row1;
2014
2015 }
2016 /* else: no need to swap anything; both values are zero */
2017 }
2018 }
2019
2020 /**
2021 * Swap in_col1 with in_col2.
2022 */
2023 template<typename eT>
2024 inline
2025 void
2026 SpMat<eT>::swap_cols(const uword in_col1, const uword in_col2)
2027 {
2028 arma_extra_debug_sigprint();
2029
2030 // slow but works
2031 for(uword lrow = 0; lrow < n_rows; ++lrow)
2032 {
2033 eT tmp = at(lrow, in_col1);
2034 at(lrow, in_col1) = at(lrow, in_col2);
2035 at(lrow, in_col2) = tmp;
2036 }
2037 }
2038
2039 /**
2040 * Remove the row row_num.
2041 */
2042 template<typename eT>
2043 inline
2044 void
2045 SpMat<eT>::shed_row(const uword row_num)
2046 {
2047 arma_extra_debug_sigprint();
2048 arma_debug_check (row_num >= n_rows, "SpMat::shed_row(): out of bounds");
2049
2050 shed_rows (row_num, row_num);
2051 }
2052
2053 /**
2054 * Remove the column col_num.
2055 */
2056 template<typename eT>
2057 inline
2058 void
2059 SpMat<eT>::shed_col(const uword col_num)
2060 {
2061 arma_extra_debug_sigprint();
2062 arma_debug_check (col_num >= n_cols, "SpMat::shed_col(): out of bounds");
2063
2064 shed_cols(col_num, col_num);
2065 }
2066
2067 /**
2068 * Remove all rows between (and including) in_row1 and in_row2.
2069 */
2070 template<typename eT>
2071 inline
2072 void
2073 SpMat<eT>::shed_rows(const uword in_row1, const uword in_row2)
2074 {
2075 arma_extra_debug_sigprint();
2076
2077 arma_debug_check
2078 (
2079 (in_row1 > in_row2) || (in_row2 >= n_rows),
2080 "SpMat::shed_rows(): indices out of bounds or incorectly used"
2081 );
2082
2083 uword i, j;
2084 // Store the length of values
2085 uword vlength = n_nonzero;
2086 // Store the length of col_ptrs
2087 uword clength = n_cols + 1;
2088
2089 // This is O(n * n_cols) and inplace, there may be a faster way, though.
2090 for (i = 0, j = 0; i < vlength; ++i)
2091 {
2092 // Store the row of the ith element.
2093 const uword lrow = row_indices[i];
2094 // Is the ith element in the range of rows we want to remove?
2095 if (lrow >= in_row1 && lrow <= in_row2)
2096 {
2097 // Increment our "removed elements" counter.
2098 ++j;
2099
2100 // Adjust the values of col_ptrs each time we remove an element.
2101 // Basically, the length of one column reduces by one, and everything to
2102 // its right gets reduced by one to represent all the elements being
2103 // shifted to the left by one.
2104 for(uword k = 0; k < clength; ++k)
2105 {
2106 if (col_ptrs[k] > (i - j + 1))
2107 {
2108 --access::rw(col_ptrs[k]);
2109 }
2110 }
2111 }
2112 else
2113 {
2114 // We shift the element we checked to the left by how many elements
2115 // we have removed.
2116 // j = 0 until we remove the first element.
2117 if (j != 0)
2118 {
2119 access::rw(row_indices[i - j]) = (lrow > in_row2) ? (lrow - (in_row2 - in_row1 + 1)) : lrow;
2120 access::rw(values[i - j]) = values[i];
2121 }
2122 }
2123 }
2124
2125 // j is the number of elements removed.
2126
2127 // Shrink the vectors. This will copy the memory.
2128 mem_resize(n_nonzero - j);
2129
2130 // Adjust row and element counts.
2131 access::rw(n_rows) = n_rows - (in_row2 - in_row1) - 1;
2132 access::rw(n_elem) = n_rows * n_cols;
2133 }
2134
2135 /**
2136 * Remove all columns between (and including) in_col1 and in_col2.
2137 */
2138 template<typename eT>
2139 inline
2140 void
2141 SpMat<eT>::shed_cols(const uword in_col1, const uword in_col2)
2142 {
2143 arma_extra_debug_sigprint();
2144
2145 arma_debug_check
2146 (
2147 (in_col1 > in_col2) || (in_col2 >= n_cols),
2148 "SpMat::shed_cols(): indices out of bounds or incorrectly used"
2149 );
2150
2151 // First we find the locations in values and row_indices for the column entries.
2152 uword col_beg = col_ptrs[in_col1];
2153 uword col_end = col_ptrs[in_col2 + 1];
2154
2155 // Then we find the number of entries in the column.
2156 uword diff = col_end - col_beg;
2157
2158 if (diff > 0)
2159 {
2160 eT* new_values = memory::acquire_chunked<eT> (n_nonzero - diff);
2161 uword* new_row_indices = memory::acquire_chunked<uword>(n_nonzero - diff);
2162
2163 // Copy first part.
2164 if (col_beg != 0)
2165 {
2166 arrayops::copy(new_values, values, col_beg);
2167 arrayops::copy(new_row_indices, row_indices, col_beg);
2168 }
2169
2170 // Copy second part.
2171 if (col_end != n_nonzero)
2172 {
2173 arrayops::copy(new_values + col_beg, values + col_end, n_nonzero - col_end);
2174 arrayops::copy(new_row_indices + col_beg, row_indices + col_end, n_nonzero - col_end);
2175 }
2176
2177 memory::release(values);
2178 memory::release(row_indices);
2179
2180 access::rw(values) = new_values;
2181 access::rw(row_indices) = new_row_indices;
2182
2183 // Update counts and such.
2184 access::rw(n_nonzero) -= diff;
2185 }
2186
2187 // Update column pointers.
2188 const uword new_n_cols = n_cols - ((in_col2 - in_col1) + 1);
2189
2190 uword* new_col_ptrs = memory::acquire<uword>(new_n_cols + 2);
2191 new_col_ptrs[new_n_cols + 1] = std::numeric_limits<uword>::max();
2192
2193 // Copy first set of columns (no manipulation required).
2194 if (in_col1 != 0)
2195 {
2196 arrayops::copy(new_col_ptrs, col_ptrs, in_col1);
2197 }
2198
2199 // Copy second set of columns (manipulation required).
2200 uword cur_col = in_col1;
2201 for (uword i = in_col2 + 1; i <= n_cols; ++i, ++cur_col)
2202 {
2203 new_col_ptrs[cur_col] = col_ptrs[i] - diff;
2204 }
2205
2206 memory::release(col_ptrs);
2207 access::rw(col_ptrs) = new_col_ptrs;
2208
2209 // We update the element and column counts, and we're done.
2210 access::rw(n_cols) = new_n_cols;
2211 access::rw(n_elem) = n_cols * n_rows;
2212 }
2213
2214
2215
2216 template<typename eT>
2217 arma_inline
2218 SpSubview<eT>
2219 SpMat<eT>::rows(const uword in_row1, const uword in_row2)
2220 {
2221 arma_extra_debug_sigprint();
2222
2223 arma_debug_check
2224 (
2225 (in_row1 > in_row2) || (in_row2 >= n_rows),
2226 "SpMat::rows(): indices out of bounds or incorrectly used"
2227 );
2228
2229 const uword subview_n_rows = in_row2 - in_row1 + 1;
2230
2231 return SpSubview<eT>(*this, in_row1, 0, subview_n_rows, n_cols);
2232 }
2233
2234
2235
2236 template<typename eT>
2237 arma_inline
2238 const SpSubview<eT>
2239 SpMat<eT>::rows(const uword in_row1, const uword in_row2) const
2240 {
2241 arma_extra_debug_sigprint();
2242
2243 arma_debug_check
2244 (
2245 (in_row1 > in_row2) || (in_row2 >= n_rows),
2246 "SpMat::rows(): indices out of bounds or incorrectly used"
2247 );
2248
2249 const uword subview_n_rows = in_row2 - in_row1 + 1;
2250
2251 return SpSubview<eT>(*this, in_row1, 0, subview_n_rows, n_cols);
2252 }
2253
2254
2255
2256 template<typename eT>
2257 arma_inline
2258 SpSubview<eT>
2259 SpMat<eT>::cols(const uword in_col1, const uword in_col2)
2260 {
2261 arma_extra_debug_sigprint();
2262
2263 arma_debug_check
2264 (
2265 (in_col1 > in_col2) || (in_col2 >= n_cols),
2266 "SpMat::cols(): indices out of bounds or incorrectly used"
2267 );
2268
2269 const uword subview_n_cols = in_col2 - in_col1 + 1;
2270
2271 return SpSubview<eT>(*this, 0, in_col1, n_rows, subview_n_cols);
2272 }
2273
2274
2275
2276 template<typename eT>
2277 arma_inline
2278 const SpSubview<eT>
2279 SpMat<eT>::cols(const uword in_col1, const uword in_col2) const
2280 {
2281 arma_extra_debug_sigprint();
2282
2283 arma_debug_check
2284 (
2285 (in_col1 > in_col2) || (in_col2 >= n_cols),
2286 "SpMat::cols(): indices out of bounds or incorrectly used"
2287 );
2288
2289 const uword subview_n_cols = in_col2 - in_col1 + 1;
2290
2291 return SpSubview<eT>(*this, 0, in_col1, n_rows, subview_n_cols);
2292 }
2293
2294
2295
2296 template<typename eT>
2297 arma_inline
2298 SpSubview<eT>
2299 SpMat<eT>::submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2)
2300 {
2301 arma_extra_debug_sigprint();
2302
2303 arma_debug_check
2304 (
2305 (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols),
2306 "SpMat::submat(): indices out of bounds or incorrectly used"
2307 );
2308
2309 const uword subview_n_rows = in_row2 - in_row1 + 1;
2310 const uword subview_n_cols = in_col2 - in_col1 + 1;
2311
2312 return SpSubview<eT>(*this, in_row1, in_col1, subview_n_rows, subview_n_cols);
2313 }
2314
2315
2316
2317 template<typename eT>
2318 arma_inline
2319 const SpSubview<eT>
2320 SpMat<eT>::submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) const
2321 {
2322 arma_extra_debug_sigprint();
2323
2324 arma_debug_check
2325 (
2326 (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols),
2327 "SpMat::submat(): indices out of bounds or incorrectly used"
2328 );
2329
2330 const uword subview_n_rows = in_row2 - in_row1 + 1;
2331 const uword subview_n_cols = in_col2 - in_col1 + 1;
2332
2333 return SpSubview<eT>(*this, in_row1, in_col1, subview_n_rows, subview_n_cols);
2334 }
2335
2336
2337
2338 template<typename eT>
2339 inline
2340 SpSubview<eT>
2341 SpMat<eT>::submat (const span& row_span, const span& col_span)
2342 {
2343 arma_extra_debug_sigprint();
2344
2345 const bool row_all = row_span.whole;
2346 const bool col_all = col_span.whole;
2347
2348 const uword local_n_rows = n_rows;
2349 const uword local_n_cols = n_cols;
2350
2351 const uword in_row1 = row_all ? 0 : row_span.a;
2352 const uword in_row2 = row_span.b;
2353 const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1;
2354
2355 const uword in_col1 = col_all ? 0 : col_span.a;
2356 const uword in_col2 = col_span.b;
2357 const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1;
2358
2359 arma_debug_check
2360 (
2361 ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) )
2362 ||
2363 ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) )
2364 ,
2365 "SpMat::submat(): indices out of bounds or incorrectly used"
2366 );
2367
2368 return SpSubview<eT>(*this, in_row1, in_col1, submat_n_rows, submat_n_cols);
2369 }
2370
2371
2372
2373 template<typename eT>
2374 inline
2375 const SpSubview<eT>
2376 SpMat<eT>::submat (const span& row_span, const span& col_span) const
2377 {
2378 arma_extra_debug_sigprint();
2379
2380 const bool row_all = row_span.whole;
2381 const bool col_all = col_span.whole;
2382
2383 const uword local_n_rows = n_rows;
2384 const uword local_n_cols = n_cols;
2385
2386 const uword in_row1 = row_all ? 0 : row_span.a;
2387 const uword in_row2 = row_span.b;
2388 const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1;
2389
2390 const uword in_col1 = col_all ? 0 : col_span.a;
2391 const uword in_col2 = col_span.b;
2392 const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1;
2393
2394 arma_debug_check
2395 (
2396 ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) )
2397 ||
2398 ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) )
2399 ,
2400 "SpMat::submat(): indices out of bounds or incorrectly used"
2401 );
2402
2403 return SpSubview<eT>(*this, in_row1, in_col1, submat_n_rows, submat_n_cols);
2404 }
2405
2406
2407
2408 template<typename eT>
2409 inline
2410 SpSubview<eT>
2411 SpMat<eT>::operator()(const span& row_span, const span& col_span)
2412 {
2413 arma_extra_debug_sigprint();
2414
2415 return submat(row_span, col_span);
2416 }
2417
2418
2419
2420 template<typename eT>
2421 inline
2422 const SpSubview<eT>
2423 SpMat<eT>::operator()(const span& row_span, const span& col_span) const
2424 {
2425 arma_extra_debug_sigprint();
2426
2427 return submat(row_span, col_span);
2428 }
2429
2430
2431
2432 /**
2433 * Element access; acces the i'th element (works identically to the Mat accessors).
2434 * If there is nothing at element i, 0 is returned.
2435 *
2436 * @param i Element to access.
2437 */
2438
2439 template<typename eT>
2440 arma_inline
2441 arma_warn_unused
2442 SpValProxy<SpMat<eT> >
2443 SpMat<eT>::operator[](const uword i)
2444 {
2445 return get_value(i);
2446 }
2447
2448
2449
2450 template<typename eT>
2451 arma_inline
2452 arma_warn_unused
2453 eT
2454 SpMat<eT>::operator[](const uword i) const
2455 {
2456 return get_value(i);
2457 }
2458
2459
2460
2461 template<typename eT>
2462 arma_inline
2463 arma_warn_unused
2464 SpValProxy<SpMat<eT> >
2465 SpMat<eT>::at(const uword i)
2466 {
2467 return get_value(i);
2468 }
2469
2470
2471
2472 template<typename eT>
2473 arma_inline
2474 arma_warn_unused
2475 eT
2476 SpMat<eT>::at(const uword i) const
2477 {
2478 return get_value(i);
2479 }
2480
2481
2482
2483 template<typename eT>
2484 arma_inline
2485 arma_warn_unused
2486 SpValProxy<SpMat<eT> >
2487 SpMat<eT>::operator()(const uword i)
2488 {
2489 arma_debug_check( (i >= n_elem), "SpMat::operator(): out of bounds");
2490 return get_value(i);
2491 }
2492
2493
2494
2495 template<typename eT>
2496 arma_inline
2497 arma_warn_unused
2498 eT
2499 SpMat<eT>::operator()(const uword i) const
2500 {
2501 arma_debug_check( (i >= n_elem), "SpMat::operator(): out of bounds");
2502 return get_value(i);
2503 }
2504
2505
2506
2507 /**
2508 * Element access; access the element at row in_rows and column in_col.
2509 * If there is nothing at that position, 0 is returned.
2510 */
2511
2512 template<typename eT>
2513 arma_inline
2514 arma_warn_unused
2515 SpValProxy<SpMat<eT> >
2516 SpMat<eT>::at(const uword in_row, const uword in_col)
2517 {
2518 return get_value(in_row, in_col);
2519 }
2520
2521
2522
2523 template<typename eT>
2524 arma_inline
2525 arma_warn_unused
2526 eT
2527 SpMat<eT>::at(const uword in_row, const uword in_col) const
2528 {
2529 return get_value(in_row, in_col);
2530 }
2531
2532
2533
2534 template<typename eT>
2535 arma_inline
2536 arma_warn_unused
2537 SpValProxy<SpMat<eT> >
2538 SpMat<eT>::operator()(const uword in_row, const uword in_col)
2539 {
2540 arma_debug_check( ((in_row >= n_rows) || (in_col >= n_cols)), "SpMat::operator(): out of bounds");
2541 return get_value(in_row, in_col);
2542 }
2543
2544
2545
2546 template<typename eT>
2547 arma_inline
2548 arma_warn_unused
2549 eT
2550 SpMat<eT>::operator()(const uword in_row, const uword in_col) const
2551 {
2552 arma_debug_check( ((in_row >= n_rows) || (in_col >= n_cols)), "SpMat::operator(): out of bounds");
2553 return get_value(in_row, in_col);
2554 }
2555
2556
2557
2558 /**
2559 * Check if matrix is empty (no size, no values).
2560 */
2561 template<typename eT>
2562 arma_inline
2563 arma_warn_unused
2564 bool
2565 SpMat<eT>::is_empty() const
2566 {
2567 return(n_elem == 0);
2568 }
2569
2570
2571
2572 //! returns true if the object can be interpreted as a column or row vector
2573 template<typename eT>
2574 arma_inline
2575 arma_warn_unused
2576 bool
2577 SpMat<eT>::is_vec() const
2578 {
2579 return ( (n_rows == 1) || (n_cols == 1) );
2580 }
2581
2582
2583
2584 //! returns true if the object can be interpreted as a row vector
2585 template<typename eT>
2586 arma_inline
2587 arma_warn_unused
2588 bool
2589 SpMat<eT>::is_rowvec() const
2590 {
2591 return (n_rows == 1);
2592 }
2593
2594
2595
2596 //! returns true if the object can be interpreted as a column vector
2597 template<typename eT>
2598 arma_inline
2599 arma_warn_unused
2600 bool
2601 SpMat<eT>::is_colvec() const
2602 {
2603 return (n_cols == 1);
2604 }
2605
2606
2607
2608 //! returns true if the object has the same number of non-zero rows and columnns
2609 template<typename eT>
2610 arma_inline
2611 arma_warn_unused
2612 bool
2613 SpMat<eT>::is_square() const
2614 {
2615 return (n_rows == n_cols);
2616 }
2617
2618
2619
2620 //! returns true if all of the elements are finite
2621 template<typename eT>
2622 inline
2623 arma_warn_unused
2624 bool
2625 SpMat<eT>::is_finite() const
2626 {
2627 for(uword i = 0; i < n_nonzero; i++)
2628 {
2629 if(arma_isfinite(values[i]) == false)
2630 {
2631 return false;
2632 }
2633 }
2634
2635 return true; // No infinite values.
2636 }
2637
2638
2639
2640 //! returns true if the given index is currently in range
2641 template<typename eT>
2642 arma_inline
2643 arma_warn_unused
2644 bool
2645 SpMat<eT>::in_range(const uword i) const
2646 {
2647 return (i < n_elem);
2648 }
2649
2650
2651 //! returns true if the given start and end indices are currently in range
2652 template<typename eT>
2653 arma_inline
2654 arma_warn_unused
2655 bool
2656 SpMat<eT>::in_range(const span& x) const
2657 {
2658 arma_extra_debug_sigprint();
2659
2660 if(x.whole == true)
2661 {
2662 return true;
2663 }
2664 else
2665 {
2666 const uword a = x.a;
2667 const uword b = x.b;
2668
2669 return ( (a <= b) && (b < n_elem) );
2670 }
2671 }
2672
2673
2674
2675 //! returns true if the given location is currently in range
2676 template<typename eT>
2677 arma_inline
2678 arma_warn_unused
2679 bool
2680 SpMat<eT>::in_range(const uword in_row, const uword in_col) const
2681 {
2682 return ( (in_row < n_rows) && (in_col < n_cols) );
2683 }
2684
2685
2686
2687 template<typename eT>
2688 arma_inline
2689 arma_warn_unused
2690 bool
2691 SpMat<eT>::in_range(const span& row_span, const uword in_col) const
2692 {
2693 arma_extra_debug_sigprint();
2694
2695 if(row_span.whole == true)
2696 {
2697 return (in_col < n_cols);
2698 }
2699 else
2700 {
2701 const uword in_row1 = row_span.a;
2702 const uword in_row2 = row_span.b;
2703
2704 return ( (in_row1 <= in_row2) && (in_row2 < n_rows) && (in_col < n_cols) );
2705 }
2706 }
2707
2708
2709
2710 template<typename eT>
2711 arma_inline
2712 arma_warn_unused
2713 bool
2714 SpMat<eT>::in_range(const uword in_row, const span& col_span) const
2715 {
2716 arma_extra_debug_sigprint();
2717
2718 if(col_span.whole == true)
2719 {
2720 return (in_row < n_rows);
2721 }
2722 else
2723 {
2724 const uword in_col1 = col_span.a;
2725 const uword in_col2 = col_span.b;
2726
2727 return ( (in_row < n_rows) && (in_col1 <= in_col2) && (in_col2 < n_cols) );
2728 }
2729 }
2730
2731
2732
2733 template<typename eT>
2734 arma_inline
2735 arma_warn_unused
2736 bool
2737 SpMat<eT>::in_range(const span& row_span, const span& col_span) const
2738 {
2739 arma_extra_debug_sigprint();
2740
2741 const uword in_row1 = row_span.a;
2742 const uword in_row2 = row_span.b;
2743
2744 const uword in_col1 = col_span.a;
2745 const uword in_col2 = col_span.b;
2746
2747 const bool rows_ok = row_span.whole ? true : ( (in_row1 <= in_row2) && (in_row2 < n_rows) );
2748 const bool cols_ok = col_span.whole ? true : ( (in_col1 <= in_col2) && (in_col2 < n_cols) );
2749
2750 return ( (rows_ok == true) && (cols_ok == true) );
2751 }
2752
2753
2754
2755 template<typename eT>
2756 inline
2757 void
2758 SpMat<eT>::impl_print(const std::string& extra_text) const
2759 {
2760 arma_extra_debug_sigprint();
2761
2762 if(extra_text.length() != 0)
2763 {
2764 const std::streamsize orig_width = ARMA_DEFAULT_OSTREAM.width();
2765
2766 ARMA_DEFAULT_OSTREAM << extra_text << '\n';
2767
2768 ARMA_DEFAULT_OSTREAM.width(orig_width);
2769 }
2770
2771 arma_ostream::print(ARMA_DEFAULT_OSTREAM, *this, true);
2772 }
2773
2774
2775
2776 template<typename eT>
2777 inline
2778 void
2779 SpMat<eT>::impl_print(std::ostream& user_stream, const std::string& extra_text) const
2780 {
2781 arma_extra_debug_sigprint();
2782
2783 if(extra_text.length() != 0)
2784 {
2785 const std::streamsize orig_width = user_stream.width();
2786
2787 user_stream << extra_text << '\n';
2788
2789 user_stream.width(orig_width);
2790 }
2791
2792 arma_ostream::print(user_stream, *this, true);
2793 }
2794
2795
2796
2797 template<typename eT>
2798 inline
2799 void
2800 SpMat<eT>::impl_raw_print(const std::string& extra_text) const
2801 {
2802 arma_extra_debug_sigprint();
2803
2804 if(extra_text.length() != 0)
2805 {
2806 const std::streamsize orig_width = ARMA_DEFAULT_OSTREAM.width();
2807
2808 ARMA_DEFAULT_OSTREAM << extra_text << '\n';
2809
2810 ARMA_DEFAULT_OSTREAM.width(orig_width);
2811 }
2812
2813 arma_ostream::print(ARMA_DEFAULT_OSTREAM, *this, false);
2814 }
2815
2816
2817 template<typename eT>
2818 inline
2819 void
2820 SpMat<eT>::impl_raw_print(std::ostream& user_stream, const std::string& extra_text) const
2821 {
2822 arma_extra_debug_sigprint();
2823
2824 if(extra_text.length() != 0)
2825 {
2826 const std::streamsize orig_width = user_stream.width();
2827
2828 user_stream << extra_text << '\n';
2829
2830 user_stream.width(orig_width);
2831 }
2832
2833 arma_ostream::print(user_stream, *this, false);
2834 }
2835
2836
2837
2838 /**
2839 * Matrix printing, prepends supplied text.
2840 * Prints 0 wherever no element exists.
2841 */
2842 template<typename eT>
2843 inline
2844 void
2845 SpMat<eT>::impl_print_dense(const std::string& extra_text) const
2846 {
2847 arma_extra_debug_sigprint();
2848
2849 if(extra_text.length() != 0)
2850 {
2851 const std::streamsize orig_width = ARMA_DEFAULT_OSTREAM.width();
2852
2853 ARMA_DEFAULT_OSTREAM << extra_text << '\n';
2854
2855 ARMA_DEFAULT_OSTREAM.width(orig_width);
2856 }
2857
2858 arma_ostream::print_dense(ARMA_DEFAULT_OSTREAM, *this, true);
2859 }
2860
2861
2862
2863 template<typename eT>
2864 inline
2865 void
2866 SpMat<eT>::impl_print_dense(std::ostream& user_stream, const std::string& extra_text) const
2867 {
2868 arma_extra_debug_sigprint();
2869
2870 if(extra_text.length() != 0)
2871 {
2872 const std::streamsize orig_width = user_stream.width();
2873
2874 user_stream << extra_text << '\n';
2875
2876 user_stream.width(orig_width);
2877 }
2878
2879 arma_ostream::print_dense(user_stream, *this, true);
2880 }
2881
2882
2883
2884 template<typename eT>
2885 inline
2886 void
2887 SpMat<eT>::impl_raw_print_dense(const std::string& extra_text) const
2888 {
2889 arma_extra_debug_sigprint();
2890
2891 if(extra_text.length() != 0)
2892 {
2893 const std::streamsize orig_width = ARMA_DEFAULT_OSTREAM.width();
2894
2895 ARMA_DEFAULT_OSTREAM << extra_text << '\n';
2896
2897 ARMA_DEFAULT_OSTREAM.width(orig_width);
2898 }
2899
2900 arma_ostream::print_dense(ARMA_DEFAULT_OSTREAM, *this, false);
2901 }
2902
2903
2904
2905 template<typename eT>
2906 inline
2907 void
2908 SpMat<eT>::impl_raw_print_dense(std::ostream& user_stream, const std::string& extra_text) const
2909 {
2910 arma_extra_debug_sigprint();
2911
2912 if(extra_text.length() != 0)
2913 {
2914 const std::streamsize orig_width = user_stream.width();
2915
2916 user_stream << extra_text << '\n';
2917
2918 user_stream.width(orig_width);
2919 }
2920
2921 arma_ostream::print_dense(user_stream, *this, false);
2922 }
2923
2924
2925
2926 //! Set the size to the size of another matrix.
2927 template<typename eT>
2928 template<typename eT2>
2929 inline
2930 void
2931 SpMat<eT>::copy_size(const SpMat<eT2>& m)
2932 {
2933 arma_extra_debug_sigprint();
2934
2935 init(m.n_rows, m.n_cols);
2936 }
2937
2938
2939
2940 template<typename eT>
2941 template<typename eT2>
2942 inline
2943 void
2944 SpMat<eT>::copy_size(const Mat<eT2>& m)
2945 {
2946 arma_extra_debug_sigprint();
2947
2948 init(m.n_rows, m.n_cols);
2949 }
2950
2951
2952
2953 /**
2954 * Resize the matrix to a given size. The matrix will be resized to be a column vector (i.e. in_elem columns, 1 row).
2955 *
2956 * @param in_elem Number of elements to allow.
2957 */
2958 template<typename eT>
2959 inline
2960 void
2961 SpMat<eT>::set_size(const uword in_elem)
2962 {
2963 arma_extra_debug_sigprint();
2964
2965 // If this is a row vector, we resize to a row vector.
2966 if(vec_state == 2)
2967 {
2968 init(1, in_elem);
2969 }
2970 else
2971 {
2972 init(in_elem, 1);
2973 }
2974 }
2975
2976
2977
2978 /**
2979 * Resize the matrix to a given size.
2980 *
2981 * @param in_rows Number of rows to allow.
2982 * @param in_cols Number of columns to allow.
2983 */
2984 template<typename eT>
2985 inline
2986 void
2987 SpMat<eT>::set_size(const uword in_rows, const uword in_cols)
2988 {
2989 arma_extra_debug_sigprint();
2990
2991 init(in_rows, in_cols);
2992 }
2993
2994
2995
2996 template<typename eT>
2997 inline
2998 void
2999 SpMat<eT>::reshape(const uword in_rows, const uword in_cols, const uword dim)
3000 {
3001 arma_extra_debug_sigprint();
3002
3003 if (dim == 0)
3004 {
3005 // We have to modify all of the relevant row indices and the relevant column pointers.
3006 // Iterate over all the points to do this. We won't be deleting any points, but we will be modifying
3007 // columns and rows. We'll have to store a new set of column vectors.
3008 uword* new_col_ptrs = memory::acquire<uword>(in_cols + 2);
3009 new_col_ptrs[in_cols + 1] = std::numeric_limits<uword>::max();
3010
3011 uword* new_row_indices = memory::acquire_chunked<uword>(n_nonzero + 1);
3012 access::rw(new_row_indices[n_nonzero]) = 0;
3013
3014 arrayops::inplace_set(new_col_ptrs, uword(0), in_cols + 1);
3015
3016 for(const_iterator it = begin(); it != end(); it++)
3017 {
3018 uword vector_position = (it.col() * n_rows) + it.row();
3019 new_row_indices[it.pos()] = vector_position % in_rows;
3020 ++new_col_ptrs[vector_position / in_rows + 1];
3021 }
3022
3023 // Now sum the column counts to get the new column pointers.
3024 for(uword i = 1; i <= in_cols; i++)
3025 {
3026 access::rw(new_col_ptrs[i]) += new_col_ptrs[i - 1];
3027 }
3028
3029 // Copy the new row indices.
3030 memory::release(row_indices);
3031 access::rw(row_indices) = new_row_indices;
3032
3033 memory::release(col_ptrs);
3034 access::rw(col_ptrs) = new_col_ptrs;
3035
3036 // Now set the size.
3037 access::rw(n_rows) = in_rows;
3038 access::rw(n_cols) = in_cols;
3039 }
3040 else
3041 {
3042 // Row-wise reshaping. This is more tedious and we will use a separate sparse matrix to do it.
3043 SpMat<eT> tmp(in_rows, in_cols);
3044
3045 for(const_row_iterator it = begin_row(); it.pos() < n_nonzero; it++)
3046 {
3047 uword vector_position = (it.row() * n_cols) + it.col();
3048
3049 tmp((vector_position / in_cols), (vector_position % in_cols)) = (*it);
3050 }
3051
3052 (*this).operator=(tmp);
3053 }
3054 }
3055
3056
3057
3058 template<typename eT>
3059 inline
3060 const SpMat<eT>&
3061 SpMat<eT>::zeros()
3062 {
3063 arma_extra_debug_sigprint();
3064
3065 if (n_nonzero > 0)
3066 {
3067 memory::release(values);
3068 memory::release(row_indices);
3069
3070 access::rw(values) = memory::acquire_chunked<eT>(1);
3071 access::rw(row_indices) = memory::acquire_chunked<uword>(1);
3072
3073 access::rw(values[0]) = 0;
3074 access::rw(row_indices[0]) = 0;
3075 }
3076
3077 access::rw(n_nonzero) = 0;
3078 arrayops::inplace_set(access::rwp(col_ptrs), uword(0), n_cols + 1);
3079
3080 return *this;
3081 }
3082
3083
3084
3085 template<typename eT>
3086 inline
3087 const SpMat<eT>&
3088 SpMat<eT>::zeros(const uword in_elem)
3089 {
3090 arma_extra_debug_sigprint();
3091
3092 if(vec_state == 2)
3093 {
3094 init(1, in_elem); // Row vector
3095 }
3096 else
3097 {
3098 init(in_elem, 1);
3099 }
3100
3101 return *this;
3102 }
3103
3104
3105
3106 template<typename eT>
3107 inline
3108 const SpMat<eT>&
3109 SpMat<eT>::zeros(const uword in_rows, const uword in_cols)
3110 {
3111 arma_extra_debug_sigprint();
3112
3113 init(in_rows, in_cols);
3114
3115 return *this;
3116 }
3117
3118
3119
3120 template<typename eT>
3121 inline
3122 const SpMat<eT>&
3123 SpMat<eT>::eye()
3124 {
3125 arma_extra_debug_sigprint();
3126
3127 return (*this).eye(n_rows, n_cols);
3128 }
3129
3130
3131
3132 template<typename eT>
3133 inline
3134 const SpMat<eT>&
3135 SpMat<eT>::eye(const uword in_rows, const uword in_cols)
3136 {
3137 arma_extra_debug_sigprint();
3138
3139 const uword N = (std::min)(in_rows, in_cols);
3140
3141 init(in_rows, in_cols);
3142
3143 mem_resize(N);
3144
3145 arrayops::inplace_set(access::rwp(values), eT(1), N);
3146
3147 for(uword i = 0; i < N; ++i) { access::rw(row_indices[i]) = i; }
3148
3149 for(uword i = 0; i <= N; ++i) { access::rw(col_ptrs[i]) = i; }
3150
3151 access::rw(n_nonzero) = N;
3152
3153 return *this;
3154 }
3155
3156
3157
3158 template<typename eT>
3159 inline
3160 const SpMat<eT>&
3161 SpMat<eT>::speye()
3162 {
3163 arma_extra_debug_sigprint();
3164
3165 return (*this).eye(n_rows, n_cols);
3166 }
3167
3168
3169
3170 template<typename eT>
3171 inline
3172 const SpMat<eT>&
3173 SpMat<eT>::speye(const uword in_n_rows, const uword in_n_cols)
3174 {
3175 arma_extra_debug_sigprint();
3176
3177 return (*this).eye(in_n_rows, in_n_cols);
3178 }
3179
3180
3181
3182 template<typename eT>
3183 inline
3184 const SpMat<eT>&
3185 SpMat<eT>::sprandu(const uword in_rows, const uword in_cols, const double density)
3186 {
3187 arma_extra_debug_sigprint();
3188
3189 arma_debug_check( ( (density < double(0)) || (density > double(1)) ), "sprandu(): density must be in the [0,1] interval" );
3190
3191 zeros(in_rows, in_cols);
3192
3193 mem_resize( uword(density * double(in_rows) * double(in_cols) + 0.5) );
3194
3195 if(n_nonzero == 0)
3196 {
3197 return *this;
3198 }
3199
3200 eop_aux_randu<eT>::fill( access::rwp(values), n_nonzero );
3201
3202 uvec indices = linspace<uvec>( 0u, in_rows*in_cols-1, n_nonzero );
3203
3204 // perturb the indices
3205 for(uword i=1; i < n_nonzero-1; ++i)
3206 {
3207 const uword index_left = indices[i-1];
3208 const uword index_right = indices[i+1];
3209
3210 const uword center = (index_left + index_right) / 2;
3211
3212 const uword delta1 = center - index_left - 1;
3213 const uword delta2 = index_right - center - 1;
3214
3215 const uword min_delta = (std::min)(delta1, delta2);
3216
3217 uword index_new = uword( double(center) + double(min_delta) * (2.0*randu()-1.0) );
3218
3219 // paranoia, but better be safe than sorry
3220 if( (index_left < index_new) && (index_new < index_right) )
3221 {
3222 indices[i] = index_new;
3223 }
3224 }
3225
3226 uword cur_index = 0;
3227 uword count = 0;
3228
3229 for(uword lcol = 0; lcol < in_cols; ++lcol)
3230 for(uword lrow = 0; lrow < in_rows; ++lrow)
3231 {
3232 if(count == indices[cur_index])
3233 {
3234 access::rw(row_indices[cur_index]) = lrow;
3235 access::rw(col_ptrs[lcol + 1])++;
3236 ++cur_index;
3237 }
3238
3239 ++count;
3240 }
3241
3242 if(cur_index != n_nonzero)
3243 {
3244 // Fix size to correct size.
3245 mem_resize(cur_index);
3246 }
3247
3248 // Sum column pointers.
3249 for(uword lcol = 1; lcol <= in_cols; ++lcol)
3250 {
3251 access::rw(col_ptrs[lcol]) += col_ptrs[lcol - 1];
3252 }
3253
3254 return *this;
3255 }
3256
3257
3258
3259 template<typename eT>
3260 inline
3261 const SpMat<eT>&
3262 SpMat<eT>::sprandn(const uword in_rows, const uword in_cols, const double density)
3263 {
3264 arma_extra_debug_sigprint();
3265
3266 arma_debug_check( ( (density < double(0)) || (density > double(1)) ), "sprandn(): density must be in the [0,1] interval" );
3267
3268 zeros(in_rows, in_cols);
3269
3270 mem_resize( uword(density * double(in_rows) * double(in_cols) + 0.5) );
3271
3272 if(n_nonzero == 0)
3273 {
3274 return *this;
3275 }
3276
3277 eop_aux_randn<eT>::fill( access::rwp(values), n_nonzero );
3278
3279 uvec indices = linspace<uvec>( 0u, in_rows*in_cols-1, n_nonzero );
3280
3281 // perturb the indices
3282 for(uword i=1; i < n_nonzero-1; ++i)
3283 {
3284 const uword index_left = indices[i-1];
3285 const uword index_right = indices[i+1];
3286
3287 const uword center = (index_left + index_right) / 2;
3288
3289 const uword delta1 = center - index_left - 1;
3290 const uword delta2 = index_right - center - 1;
3291
3292 const uword min_delta = (std::min)(delta1, delta2);
3293
3294 uword index_new = uword( double(center) + double(min_delta) * (2.0*randu()-1.0) );
3295
3296 // paranoia, but better be safe than sorry
3297 if( (index_left < index_new) && (index_new < index_right) )
3298 {
3299 indices[i] = index_new;
3300 }
3301 }
3302
3303 uword cur_index = 0;
3304 uword count = 0;
3305
3306 for(uword lcol = 0; lcol < in_cols; ++lcol)
3307 for(uword lrow = 0; lrow < in_rows; ++lrow)
3308 {
3309 if(count == indices[cur_index])
3310 {
3311 access::rw(row_indices[cur_index]) = lrow;
3312 access::rw(col_ptrs[lcol + 1])++;
3313 ++cur_index;
3314 }
3315
3316 ++count;
3317 }
3318
3319 if(cur_index != n_nonzero)
3320 {
3321 // Fix size to correct size.
3322 mem_resize(cur_index);
3323 }
3324
3325 // Sum column pointers.
3326 for(uword lcol = 1; lcol <= in_cols; ++lcol)
3327 {
3328 access::rw(col_ptrs[lcol]) += col_ptrs[lcol - 1];
3329 }
3330
3331 return *this;
3332 }
3333
3334
3335
3336 template<typename eT>
3337 inline
3338 void
3339 SpMat<eT>::reset()
3340 {
3341 arma_extra_debug_sigprint();
3342
3343 set_size(0, 0);
3344 }
3345
3346
3347
3348 /**
3349 * Get the minimum or the maximum of the matrix.
3350 */
3351 template<typename eT>
3352 inline
3353 arma_warn_unused
3354 eT
3355 SpMat<eT>::min() const
3356 {
3357 arma_extra_debug_sigprint();
3358
3359 arma_debug_check((n_elem == 0), "min(): object has no elements");
3360
3361 if (n_nonzero == 0)
3362 {
3363 return 0;
3364 }
3365
3366 eT val = op_min::direct_min(values, n_nonzero);
3367
3368 if ((val > 0) && (n_nonzero < n_elem)) // A sparse 0 is less.
3369 {
3370 val = 0;
3371 }
3372
3373 return val;
3374 }
3375
3376
3377
3378 template<typename eT>
3379 inline
3380 eT
3381 SpMat<eT>::min(uword& index_of_min_val) const
3382 {
3383 arma_extra_debug_sigprint();
3384
3385 arma_debug_check((n_elem == 0), "min(): object has no elements");
3386
3387 eT val = 0;
3388
3389 if (n_nonzero == 0) // There are no other elements. It must be 0.
3390 {
3391 index_of_min_val = 0;
3392 }
3393 else
3394 {
3395 uword location;
3396 val = op_min::direct_min(values, n_nonzero, location);
3397
3398 if ((val > 0) && (n_nonzero < n_elem)) // A sparse 0 is less.
3399 {
3400 val = 0;
3401
3402 // Give back the index to the first zero position.
3403 index_of_min_val = 0;
3404 while (get_position(index_of_min_val) == index_of_min_val) // An element exists at that position.
3405 {
3406 index_of_min_val++;
3407 }
3408
3409 }
3410 else
3411 {
3412 index_of_min_val = get_position(location);
3413 }
3414 }
3415
3416 return val;
3417
3418 }
3419
3420
3421
3422 template<typename eT>
3423 inline
3424 eT
3425 SpMat<eT>::min(uword& row_of_min_val, uword& col_of_min_val) const
3426 {
3427 arma_extra_debug_sigprint();
3428
3429 arma_debug_check((n_elem == 0), "min(): object has no elements");
3430
3431 eT val = 0;
3432
3433 if (n_nonzero == 0) // There are no other elements. It must be 0.
3434 {
3435 row_of_min_val = 0;
3436 col_of_min_val = 0;
3437 }
3438 else
3439 {
3440 uword location;
3441 val = op_min::direct_min(values, n_nonzero, location);
3442
3443 if ((val > 0) && (n_nonzero < n_elem)) // A sparse 0 is less.
3444 {
3445 val = 0;
3446
3447 location = 0;
3448 while (get_position(location) == location) // An element exists at that position.
3449 {
3450 location++;
3451 }
3452
3453 row_of_min_val = location % n_rows;
3454 col_of_min_val = location / n_rows;
3455 }
3456 else
3457 {
3458 get_position(location, row_of_min_val, col_of_min_val);
3459 }
3460 }
3461
3462 return val;
3463
3464 }
3465
3466
3467
3468 template<typename eT>
3469 inline
3470 arma_warn_unused
3471 eT
3472 SpMat<eT>::max() const
3473 {
3474 arma_extra_debug_sigprint();
3475
3476 arma_debug_check((n_elem == 0), "max(): object has no elements");
3477
3478 if (n_nonzero == 0)
3479 {
3480 return 0;
3481 }
3482
3483 eT val = op_max::direct_max(values, n_nonzero);
3484
3485 if ((val < 0) && (n_nonzero < n_elem)) // A sparse 0 is more.
3486 {
3487 return 0;
3488 }
3489
3490 return val;
3491
3492 }
3493
3494
3495
3496 template<typename eT>
3497 inline
3498 eT
3499 SpMat<eT>::max(uword& index_of_max_val) const
3500 {
3501 arma_extra_debug_sigprint();
3502
3503 arma_debug_check((n_elem == 0), "max(): object has no elements");
3504
3505 eT val = 0;
3506
3507 if (n_nonzero == 0)
3508 {
3509 index_of_max_val = 0;
3510 }
3511 else
3512 {
3513 uword location;
3514 val = op_max::direct_max(values, n_nonzero, location);
3515
3516 if ((val < 0) && (n_nonzero < n_elem)) // A sparse 0 is more.
3517 {
3518 val = 0;
3519
3520 location = 0;
3521 while (get_position(location) == location) // An element exists at that position.
3522 {
3523 location++;
3524 }
3525
3526 }
3527 else
3528 {
3529 index_of_max_val = get_position(location);
3530 }
3531
3532 }
3533
3534 return val;
3535
3536 }
3537
3538
3539
3540 template<typename eT>
3541 inline
3542 eT
3543 SpMat<eT>::max(uword& row_of_max_val, uword& col_of_max_val) const
3544 {
3545 arma_extra_debug_sigprint();
3546
3547 arma_debug_check((n_elem == 0), "max(): object has no elements");
3548
3549 eT val = 0;
3550
3551 if (n_nonzero == 0)
3552 {
3553 row_of_max_val = 0;
3554 col_of_max_val = 0;
3555 }
3556 else
3557 {
3558 uword location;
3559 val = op_max::direct_max(values, n_nonzero, location);
3560
3561 if ((val < 0) && (n_nonzero < n_elem)) // A sparse 0 is more.
3562 {
3563 val = 0;
3564
3565 location = 0;
3566 while (get_position(location) == location) // An element exists at that position.
3567 {
3568 location++;
3569 }
3570
3571 row_of_max_val = location % n_rows;
3572 col_of_max_val = location / n_rows;
3573
3574 }
3575 else
3576 {
3577 get_position(location, row_of_max_val, col_of_max_val);
3578 }
3579
3580 }
3581
3582 return val;
3583
3584 }
3585
3586
3587
3588 //! save the matrix to a file
3589 template<typename eT>
3590 inline
3591 bool
3592 SpMat<eT>::save(const std::string name, const file_type type, const bool print_status) const
3593 {
3594 arma_extra_debug_sigprint();
3595
3596 bool save_okay;
3597
3598 switch(type)
3599 {
3600 // case raw_ascii:
3601 // save_okay = diskio::save_raw_ascii(*this, name);
3602 // break;
3603
3604 // case csv_ascii:
3605 // save_okay = diskio::save_csv_ascii(*this, name);
3606 // break;
3607
3608 case arma_binary:
3609 save_okay = diskio::save_arma_binary(*this, name);
3610 break;
3611
3612 case coord_ascii:
3613 save_okay = diskio::save_coord_ascii(*this, name);
3614 break;
3615
3616 default:
3617 arma_warn(true, "SpMat::save(): unsupported file type");
3618 save_okay = false;
3619 }
3620
3621 arma_warn( (save_okay == false), "SpMat::save(): couldn't write to ", name);
3622
3623 return save_okay;
3624 }
3625
3626
3627
3628 //! save the matrix to a stream
3629 template<typename eT>
3630 inline
3631 bool
3632 SpMat<eT>::save(std::ostream& os, const file_type type, const bool print_status) const
3633 {
3634 arma_extra_debug_sigprint();
3635
3636 bool save_okay;
3637
3638 switch(type)
3639 {
3640 // case raw_ascii:
3641 // save_okay = diskio::save_raw_ascii(*this, os);
3642 // break;
3643
3644 // case csv_ascii:
3645 // save_okay = diskio::save_csv_ascii(*this, os);
3646 // break;
3647
3648 case arma_binary:
3649 save_okay = diskio::save_arma_binary(*this, os);
3650 break;
3651
3652 case coord_ascii:
3653 save_okay = diskio::save_coord_ascii(*this, os);
3654 break;
3655
3656 default:
3657 arma_warn(true, "SpMat::save(): unsupported file type");
3658 save_okay = false;
3659 }
3660
3661 arma_warn( (save_okay == false), "SpMat::save(): couldn't write to the given stream");
3662
3663 return save_okay;
3664 }
3665
3666
3667
3668 //! load a matrix from a file
3669 template<typename eT>
3670 inline
3671 bool
3672 SpMat<eT>::load(const std::string name, const file_type type, const bool print_status)
3673 {
3674 arma_extra_debug_sigprint();
3675
3676 bool load_okay;
3677 std::string err_msg;
3678
3679 switch(type)
3680 {
3681 // case auto_detect:
3682 // load_okay = diskio::load_auto_detect(*this, name, err_msg);
3683 // break;
3684
3685 // case raw_ascii:
3686 // load_okay = diskio::load_raw_ascii(*this, name, err_msg);
3687 // break;
3688
3689 // case csv_ascii:
3690 // load_okay = diskio::load_csv_ascii(*this, name, err_msg);
3691 // break;
3692
3693 case arma_binary:
3694 load_okay = diskio::load_arma_binary(*this, name, err_msg);
3695 break;
3696
3697 case coord_ascii:
3698 load_okay = diskio::load_coord_ascii(*this, name, err_msg);
3699 break;
3700
3701 default:
3702 arma_warn(true, "SpMat::load(): unsupported file type");
3703 load_okay = false;
3704 }
3705
3706 if(load_okay == false)
3707 {
3708 if(err_msg.length() > 0)
3709 {
3710 arma_warn(true, "SpMat::load(): ", err_msg, name);
3711 }
3712 else
3713 {
3714 arma_warn(true, "SpMat::load(): couldn't read ", name);
3715 }
3716 }
3717
3718 if(load_okay == false)
3719 {
3720 (*this).reset();
3721 }
3722
3723 return load_okay;
3724 }
3725
3726
3727
3728 //! load a matrix from a stream
3729 template<typename eT>
3730 inline
3731 bool
3732 SpMat<eT>::load(std::istream& is, const file_type type, const bool print_status)
3733 {
3734 arma_extra_debug_sigprint();
3735
3736 bool load_okay;
3737 std::string err_msg;
3738
3739 switch(type)
3740 {
3741 // case auto_detect:
3742 // load_okay = diskio::load_auto_detect(*this, is, err_msg);
3743 // break;
3744
3745 // case raw_ascii:
3746 // load_okay = diskio::load_raw_ascii(*this, is, err_msg);
3747 // break;
3748
3749 // case csv_ascii:
3750 // load_okay = diskio::load_csv_ascii(*this, is, err_msg);
3751 // break;
3752
3753 case arma_binary:
3754 load_okay = diskio::load_arma_binary(*this, is, err_msg);
3755 break;
3756
3757 case coord_ascii:
3758 load_okay = diskio::load_coord_ascii(*this, is, err_msg);
3759 break;
3760
3761 default:
3762 arma_warn(true, "SpMat::load(): unsupported file type");
3763 load_okay = false;
3764 }
3765
3766
3767 if(load_okay == false)
3768 {
3769 if(err_msg.length() > 0)
3770 {
3771 arma_warn(true, "SpMat::load(): ", err_msg, "the given stream");
3772 }
3773 else
3774 {
3775 arma_warn(true, "SpMat::load(): couldn't load from the given stream");
3776 }
3777 }
3778
3779 if(load_okay == false)
3780 {
3781 (*this).reset();
3782 }
3783
3784 return load_okay;
3785 }
3786
3787
3788
3789 //! save the matrix to a file, without printing any error messages
3790 template<typename eT>
3791 inline
3792 bool
3793 SpMat<eT>::quiet_save(const std::string name, const file_type type) const
3794 {
3795 arma_extra_debug_sigprint();
3796
3797 return (*this).save(name, type, false);
3798 }
3799
3800
3801
3802 //! save the matrix to a stream, without printing any error messages
3803 template<typename eT>
3804 inline
3805 bool
3806 SpMat<eT>::quiet_save(std::ostream& os, const file_type type) const
3807 {
3808 arma_extra_debug_sigprint();
3809
3810 return (*this).save(os, type, false);
3811 }
3812
3813
3814
3815 //! load a matrix from a file, without printing any error messages
3816 template<typename eT>
3817 inline
3818 bool
3819 SpMat<eT>::quiet_load(const std::string name, const file_type type)
3820 {
3821 arma_extra_debug_sigprint();
3822
3823 return (*this).load(name, type, false);
3824 }
3825
3826
3827
3828 //! load a matrix from a stream, without printing any error messages
3829 template<typename eT>
3830 inline
3831 bool
3832 SpMat<eT>::quiet_load(std::istream& is, const file_type type)
3833 {
3834 arma_extra_debug_sigprint();
3835
3836 return (*this).load(is, type, false);
3837 }
3838
3839
3840
3841 /**
3842 * Initialize the matrix to the specified size. Data is not preserved, so the matrix is assumed to be entirely sparse (empty).
3843 */
3844 template<typename eT>
3845 inline
3846 void
3847 SpMat<eT>::init(uword in_rows, uword in_cols)
3848 {
3849 arma_extra_debug_sigprint();
3850
3851 // Verify that we are allowed to do this.
3852 if(vec_state > 0)
3853 {
3854 if((in_rows == 0) && (in_cols == 0))
3855 {
3856 if(vec_state == 1)
3857 {
3858 in_cols = 1;
3859 }
3860 else
3861 if(vec_state == 2)
3862 {
3863 in_rows = 1;
3864 }
3865 }
3866 else
3867 {
3868 arma_debug_check
3869 (
3870 ( ((vec_state == 1) && (in_cols != 1)) || ((vec_state == 2) && (in_rows != 1)) ),
3871 "SpMat::init(): object is a row or column vector; requested size is not compatible"
3872 );
3873 }
3874 }
3875
3876 // Ensure that n_elem can hold the result of (n_rows * n_cols)
3877 arma_debug_check
3878 (
3879 (
3880 ( (in_rows > ARMA_MAX_UHWORD) || (in_cols > ARMA_MAX_UHWORD) )
3881 ? ( (float(in_rows) * float(in_cols)) > float(ARMA_MAX_UWORD) )
3882 : false
3883 ),
3884 "SpMat::init(): requested size is too large"
3885 );
3886
3887 // Clean out the existing memory.
3888 if (values)
3889 {
3890 memory::release(values);
3891 memory::release(row_indices);
3892 }
3893
3894 access::rw(values) = memory::acquire_chunked<eT> (1);
3895 access::rw(row_indices) = memory::acquire_chunked<uword>(1);
3896
3897 access::rw(values[0]) = 0;
3898 access::rw(row_indices[0]) = 0;
3899
3900 memory::release(col_ptrs);
3901
3902 // Set the new size accordingly.
3903 access::rw(n_rows) = in_rows;
3904 access::rw(n_cols) = in_cols;
3905 access::rw(n_elem) = (in_rows * in_cols);
3906 access::rw(n_nonzero) = 0;
3907
3908 // Try to allocate the column pointers, filling them with 0, except for the
3909 // last element which contains the maximum possible element (so iterators
3910 // terminate correctly).
3911 access::rw(col_ptrs) = memory::acquire<uword>(in_cols + 2);
3912 access::rw(col_ptrs[in_cols + 1]) = std::numeric_limits<uword>::max();
3913
3914 arrayops::inplace_set(access::rwp(col_ptrs), uword(0), in_cols + 1);
3915 }
3916
3917
3918
3919 /**
3920 * Initialize the matrix from a string.
3921 */
3922 template<typename eT>
3923 inline
3924 void
3925 SpMat<eT>::init(const std::string& text)
3926 {
3927 arma_extra_debug_sigprint();
3928
3929 // Figure out the size first.
3930 uword t_n_rows = 0;
3931 uword t_n_cols = 0;
3932
3933 bool t_n_cols_found = false;
3934
3935 std::string token;
3936
3937 std::string::size_type line_start = 0;
3938 std::string::size_type line_end = 0;
3939
3940 while (line_start < text.length())
3941 {
3942
3943 line_end = text.find(';', line_start);
3944
3945 if (line_end == std::string::npos)
3946 line_end = text.length() - 1;
3947
3948 std::string::size_type line_len = line_end - line_start + 1;
3949 std::stringstream line_stream(text.substr(line_start, line_len));
3950
3951 // Step through each column.
3952 uword line_n_cols = 0;
3953
3954 while (line_stream >> token)
3955 {
3956 ++line_n_cols;
3957 }
3958
3959 if (line_n_cols > 0)
3960 {
3961 if (t_n_cols_found == false)
3962 {
3963 t_n_cols = line_n_cols;
3964 t_n_cols_found = true;
3965 }
3966 else // Check it each time through, just to make sure.
3967 arma_check((line_n_cols != t_n_cols), "SpMat::init(): inconsistent number of columns in given string");
3968
3969 ++t_n_rows;
3970 }
3971
3972 line_start = line_end + 1;
3973
3974 }
3975
3976 set_size(t_n_rows, t_n_cols);
3977
3978 // Second time through will pick up all the values.
3979 line_start = 0;
3980 line_end = 0;
3981
3982 uword lrow = 0;
3983
3984 while (line_start < text.length())
3985 {
3986
3987 line_end = text.find(';', line_start);
3988
3989 if (line_end == std::string::npos)
3990 line_end = text.length() - 1;
3991
3992 std::string::size_type line_len = line_end - line_start + 1;
3993 std::stringstream line_stream(text.substr(line_start, line_len));
3994
3995 uword lcol = 0;
3996 eT val;
3997
3998 while (line_stream >> val)
3999 {
4000 // Only add nonzero elements.
4001 if (val != eT(0))
4002 {
4003 get_value(lrow, lcol) = val;
4004 }
4005
4006 ++lcol;
4007 }
4008
4009 ++lrow;
4010 line_start = line_end + 1;
4011
4012 }
4013
4014 }
4015
4016 /**
4017 * Copy from another matrix.
4018 */
4019 template<typename eT>
4020 inline
4021 void
4022 SpMat<eT>::init(const SpMat<eT>& x)
4023 {
4024 arma_extra_debug_sigprint();
4025
4026 // Ensure we are not initializing to ourselves.
4027 if (this != &x)
4028 {
4029 init(x.n_rows, x.n_cols);
4030
4031 // values and row_indices may not be null.
4032 if (values != NULL)
4033 {
4034 memory::release(values);
4035 memory::release(row_indices);
4036 }
4037
4038 access::rw(values) = memory::acquire_chunked<eT> (x.n_nonzero + 1);
4039 access::rw(row_indices) = memory::acquire_chunked<uword>(x.n_nonzero + 1);
4040
4041 // Now copy over the elements.
4042 arrayops::copy(access::rwp(values), x.values, x.n_nonzero + 1);
4043 arrayops::copy(access::rwp(row_indices), x.row_indices, x.n_nonzero + 1);
4044 arrayops::copy(access::rwp(col_ptrs), x.col_ptrs, x.n_cols + 1);
4045
4046 access::rw(n_nonzero) = x.n_nonzero;
4047 }
4048 }
4049
4050
4051
4052 template<typename eT>
4053 inline
4054 void
4055 SpMat<eT>::mem_resize(const uword new_n_nonzero)
4056 {
4057 arma_extra_debug_sigprint();
4058
4059 if(n_nonzero != new_n_nonzero)
4060 {
4061 if(new_n_nonzero == 0)
4062 {
4063 memory::release(values);
4064 memory::release(row_indices);
4065
4066 access::rw(values) = memory::acquire_chunked<eT> (1);
4067 access::rw(row_indices) = memory::acquire_chunked<uword>(1);
4068
4069 access::rw(values[0]) = 0;
4070 access::rw(row_indices[0]) = 0;
4071 }
4072 else
4073 {
4074 // Figure out the actual amount of memory currently allocated
4075 // NOTE: this relies on memory::acquire_chunked() being used for the 'values' and 'row_indices' arrays
4076 const uword n_alloc = memory::enlarge_to_mult_of_chunksize(n_nonzero);
4077
4078 if(n_alloc < new_n_nonzero)
4079 {
4080 eT* new_values = memory::acquire_chunked<eT> (new_n_nonzero + 1);
4081 uword* new_row_indices = memory::acquire_chunked<uword>(new_n_nonzero + 1);
4082
4083 if(n_nonzero > 0)
4084 {
4085 // Copy old elements.
4086 uword copy_len = std::min(n_nonzero, new_n_nonzero);
4087
4088 arrayops::copy(new_values, values, copy_len);
4089 arrayops::copy(new_row_indices, row_indices, copy_len);
4090 }
4091
4092 memory::release(values);
4093 memory::release(row_indices);
4094
4095 access::rw(values) = new_values;
4096 access::rw(row_indices) = new_row_indices;
4097 }
4098
4099 // Set the "fake end" of the matrix by setting the last value and row
4100 // index to 0. This helps the iterators work correctly.
4101 access::rw(values[new_n_nonzero]) = 0;
4102 access::rw(row_indices[new_n_nonzero]) = 0;
4103 }
4104
4105 access::rw(n_nonzero) = new_n_nonzero;
4106 }
4107 }
4108
4109
4110
4111 // Steal memory from another matrix.
4112 template<typename eT>
4113 inline
4114 void
4115 SpMat<eT>::steal_mem(SpMat<eT>& x)
4116 {
4117 arma_extra_debug_sigprint();
4118
4119 if(this != &x)
4120 {
4121 // Release all the memory.
4122 memory::release(values);
4123 memory::release(row_indices);
4124 memory::release(col_ptrs);
4125
4126 // We'll have to copy everything about the other matrix.
4127 const uword x_n_rows = x.n_rows;
4128 const uword x_n_cols = x.n_cols;
4129 const uword x_n_elem = x.n_elem;
4130 const uword x_n_nonzero = x.n_nonzero;
4131
4132 access::rw(n_rows) = x_n_rows;
4133 access::rw(n_cols) = x_n_cols;
4134 access::rw(n_elem) = x_n_elem;
4135 access::rw(n_nonzero) = x_n_nonzero;
4136
4137 access::rw(values) = x.values;
4138 access::rw(row_indices) = x.row_indices;
4139 access::rw(col_ptrs) = x.col_ptrs;
4140
4141 // Set other matrix to empty.
4142 access::rw(x.n_rows) = 0;
4143 access::rw(x.n_cols) = 0;
4144 access::rw(x.n_elem) = 0;
4145 access::rw(x.n_nonzero) = 0;
4146
4147 access::rw(x.values) = NULL;
4148 access::rw(x.row_indices) = NULL;
4149 access::rw(x.col_ptrs) = NULL;
4150 }
4151 }
4152
4153
4154
4155 template<typename eT>
4156 template<typename T1, typename Functor>
4157 arma_hot
4158 inline
4159 void
4160 SpMat<eT>::init_xform(const SpBase<eT,T1>& A, const Functor& func)
4161 {
4162 arma_extra_debug_sigprint();
4163
4164 // if possible, avoid doing a copy and instead apply func to the generated elements
4165 if(SpProxy<T1>::Q_created_by_proxy == true)
4166 {
4167 (*this) = A.get_ref();
4168
4169 const uword nnz = n_nonzero;
4170
4171 eT* t_values = access::rwp(values);
4172
4173 for(uword i=0; i < nnz; ++i)
4174 {
4175 t_values[i] = func(t_values[i]);
4176 }
4177 }
4178 else
4179 {
4180 init_xform_mt(A.get_ref(), func);
4181 }
4182 }
4183
4184
4185
4186 template<typename eT>
4187 template<typename eT2, typename T1, typename Functor>
4188 arma_hot
4189 inline
4190 void
4191 SpMat<eT>::init_xform_mt(const SpBase<eT2,T1>& A, const Functor& func)
4192 {
4193 arma_extra_debug_sigprint();
4194
4195 const SpProxy<T1> P(A.get_ref());
4196
4197 if( (P.is_alias(*this) == true) || (is_SpMat<typename SpProxy<T1>::stored_type>::value == true) )
4198 {
4199 // NOTE: unwrap_spmat will convert a submatrix to a matrix, which in effect takes care of aliasing with submatrices;
4200 // NOTE: however, when more delayed ops are implemented, more elaborate handling of aliasing will be necessary
4201 const unwrap_spmat<typename SpProxy<T1>::stored_type> tmp(P.Q);
4202
4203 const SpMat<eT2>& x = tmp.M;
4204
4205 if(void_ptr(this) != void_ptr(&x))
4206 {
4207 init(x.n_rows, x.n_cols);
4208
4209 // values and row_indices may not be null.
4210 if(values != NULL)
4211 {
4212 memory::release(values);
4213 memory::release(row_indices);
4214 }
4215
4216 access::rw(values) = memory::acquire_chunked<eT> (x.n_nonzero + 1);
4217 access::rw(row_indices) = memory::acquire_chunked<uword>(x.n_nonzero + 1);
4218
4219 arrayops::copy(access::rwp(row_indices), x.row_indices, x.n_nonzero + 1);
4220 arrayops::copy(access::rwp(col_ptrs), x.col_ptrs, x.n_cols + 1);
4221
4222 access::rw(n_nonzero) = x.n_nonzero;
4223 }
4224
4225
4226 // initialise the elements array with a transformed version of the elements from x
4227
4228 const uword nnz = n_nonzero;
4229
4230 const eT2* x_values = x.values;
4231 eT* t_values = access::rwp(values);
4232
4233 for(uword i=0; i < nnz; ++i)
4234 {
4235 t_values[i] = func(x_values[i]); // NOTE: func() must produce a value of type eT (ie. act as a convertor between eT2 and eT)
4236 }
4237 }
4238 else
4239 {
4240 init(P.get_n_rows(), P.get_n_cols());
4241
4242 mem_resize(P.get_n_nonzero());
4243
4244 typename SpProxy<T1>::const_iterator_type it = P.begin();
4245
4246 while(it != P.end())
4247 {
4248 access::rw(row_indices[it.pos()]) = it.row();
4249 access::rw(values[it.pos()]) = func(*it); // NOTE: func() must produce a value of type eT (ie. act as a convertor between eT2 and eT)
4250 ++access::rw(col_ptrs[it.col() + 1]);
4251 ++it;
4252 }
4253
4254 // Now sum column pointers.
4255 for(uword c = 1; c <= n_cols; ++c)
4256 {
4257 access::rw(col_ptrs[c]) += col_ptrs[c - 1];
4258 }
4259 }
4260 }
4261
4262
4263
4264 template<typename eT>
4265 inline
4266 typename SpMat<eT>::iterator
4267 SpMat<eT>::begin()
4268 {
4269 return iterator(*this);
4270 }
4271
4272
4273
4274 template<typename eT>
4275 inline
4276 typename SpMat<eT>::const_iterator
4277 SpMat<eT>::begin() const
4278 {
4279 return const_iterator(*this);
4280 }
4281
4282
4283
4284 template<typename eT>
4285 inline
4286 typename SpMat<eT>::iterator
4287 SpMat<eT>::end()
4288 {
4289 return iterator(*this, 0, n_cols, n_nonzero);
4290 }
4291
4292
4293
4294 template<typename eT>
4295 inline
4296 typename SpMat<eT>::const_iterator
4297 SpMat<eT>::end() const
4298 {
4299 return const_iterator(*this, 0, n_cols, n_nonzero);
4300 }
4301
4302
4303
4304 template<typename eT>
4305 inline
4306 typename SpMat<eT>::iterator
4307 SpMat<eT>::begin_col(const uword col_num)
4308 {
4309 return iterator(*this, 0, col_num);
4310 }
4311
4312
4313
4314 template<typename eT>
4315 inline
4316 typename SpMat<eT>::const_iterator
4317 SpMat<eT>::begin_col(const uword col_num) const
4318 {
4319 return const_iterator(*this, 0, col_num);
4320 }
4321
4322
4323
4324 template<typename eT>
4325 inline
4326 typename SpMat<eT>::iterator
4327 SpMat<eT>::end_col(const uword col_num)
4328 {
4329 return iterator(*this, 0, col_num + 1);
4330 }
4331
4332
4333
4334 template<typename eT>
4335 inline
4336 typename SpMat<eT>::const_iterator
4337 SpMat<eT>::end_col(const uword col_num) const
4338 {
4339 return const_iterator(*this, 0, col_num + 1);
4340 }
4341
4342
4343
4344 template<typename eT>
4345 inline
4346 typename SpMat<eT>::row_iterator
4347 SpMat<eT>::begin_row(const uword row_num)
4348 {
4349 return row_iterator(*this, row_num, 0);
4350 }
4351
4352
4353
4354 template<typename eT>
4355 inline
4356 typename SpMat<eT>::const_row_iterator
4357 SpMat<eT>::begin_row(const uword row_num) const
4358 {
4359 return const_row_iterator(*this, row_num, 0);
4360 }
4361
4362
4363
4364 template<typename eT>
4365 inline
4366 typename SpMat<eT>::row_iterator
4367 SpMat<eT>::end_row()
4368 {
4369 return row_iterator(*this, n_nonzero);
4370 }
4371
4372
4373
4374 template<typename eT>
4375 inline
4376 typename SpMat<eT>::const_row_iterator
4377 SpMat<eT>::end_row() const
4378 {
4379 return const_row_iterator(*this, n_nonzero);
4380 }
4381
4382
4383
4384 template<typename eT>
4385 inline
4386 typename SpMat<eT>::row_iterator
4387 SpMat<eT>::end_row(const uword row_num)
4388 {
4389 return row_iterator(*this, row_num + 1, 0);
4390 }
4391
4392
4393
4394 template<typename eT>
4395 inline
4396 typename SpMat<eT>::const_row_iterator
4397 SpMat<eT>::end_row(const uword row_num) const
4398 {
4399 return const_row_iterator(*this, row_num + 1, 0);
4400 }
4401
4402
4403
4404 template<typename eT>
4405 inline
4406 void
4407 SpMat<eT>::clear()
4408 {
4409 if (values)
4410 {
4411 memory::release(values);
4412 memory::release(row_indices);
4413
4414 access::rw(values) = memory::acquire_chunked<eT> (1);
4415 access::rw(row_indices) = memory::acquire_chunked<uword>(1);
4416
4417 access::rw(values[0]) = 0;
4418 access::rw(row_indices[0]) = 0;
4419 }
4420
4421 memory::release(col_ptrs);
4422
4423 access::rw(col_ptrs) = memory::acquire<uword>(n_cols + 2);
4424 access::rw(col_ptrs[n_cols + 1]) = std::numeric_limits<uword>::max();
4425
4426 arrayops::inplace_set(col_ptrs, eT(0), n_cols + 1);
4427
4428 access::rw(n_nonzero) = 0;
4429 }
4430
4431
4432
4433 template<typename eT>
4434 inline
4435 bool
4436 SpMat<eT>::empty() const
4437 {
4438 return (n_elem == 0);
4439 }
4440
4441
4442
4443 template<typename eT>
4444 inline
4445 uword
4446 SpMat<eT>::size() const
4447 {
4448 return n_elem;
4449 }
4450
4451
4452
4453 template<typename eT>
4454 inline
4455 arma_hot
4456 arma_warn_unused
4457 SpValProxy<SpMat<eT> >
4458 SpMat<eT>::get_value(const uword i)
4459 {
4460 // First convert to the actual location.
4461 uword lcol = i / n_rows; // Integer division.
4462 uword lrow = i % n_rows;
4463
4464 return get_value(lrow, lcol);
4465 }
4466
4467
4468
4469 template<typename eT>
4470 inline
4471 arma_hot
4472 arma_warn_unused
4473 eT
4474 SpMat<eT>::get_value(const uword i) const
4475 {
4476 // First convert to the actual location.
4477 uword lcol = i / n_rows; // Integer division.
4478 uword lrow = i % n_rows;
4479
4480 return get_value(lrow, lcol);
4481 }
4482
4483
4484
4485 template<typename eT>
4486 inline
4487 arma_hot
4488 arma_warn_unused
4489 SpValProxy<SpMat<eT> >
4490 SpMat<eT>::get_value(const uword in_row, const uword in_col)
4491 {
4492 const uword colptr = col_ptrs[in_col];
4493 const uword next_colptr = col_ptrs[in_col + 1];
4494
4495 // Step through the row indices to see if our element exists.
4496 for (uword i = colptr; i < next_colptr; ++i)
4497 {
4498 const uword row_index = row_indices[i];
4499
4500 // First check that we have not stepped past it.
4501 if (in_row < row_index) // If we have, then it doesn't exist: return 0.
4502 {
4503 return SpValProxy<SpMat<eT> >(in_row, in_col, *this); // Proxy for a zero value.
4504 }
4505
4506 // Now check if we are at the correct place.
4507 if (in_row == row_index) // If we are, return a reference to the value.
4508 {
4509 return SpValProxy<SpMat<eT> >(in_row, in_col, *this, &access::rw(values[i]));
4510 }
4511
4512 }
4513
4514 // We did not find it, so it does not exist: return 0.
4515 return SpValProxy<SpMat<eT> >(in_row, in_col, *this);
4516 }
4517
4518
4519
4520 template<typename eT>
4521 inline
4522 arma_hot
4523 arma_warn_unused
4524 eT
4525 SpMat<eT>::get_value(const uword in_row, const uword in_col) const
4526 {
4527 const uword colptr = col_ptrs[in_col];
4528 const uword next_colptr = col_ptrs[in_col + 1];
4529
4530 // Step through the row indices to see if our element exists.
4531 for (uword i = colptr; i < next_colptr; ++i)
4532 {
4533 const uword row_index = row_indices[i];
4534
4535 // First check that we have not stepped past it.
4536 if (in_row < row_index) // If we have, then it doesn't exist: return 0.
4537 {
4538 return eT(0);
4539 }
4540
4541 // Now check if we are at the correct place.
4542 if (in_row == row_index) // If we are, return the value.
4543 {
4544 return values[i];
4545 }
4546 }
4547
4548 // We did not find it, so it does not exist: return 0.
4549 return eT(0);
4550 }
4551
4552
4553
4554 /**
4555 * Given the index representing which of the nonzero values this is, return its
4556 * actual location, either in row/col or just the index.
4557 */
4558 template<typename eT>
4559 arma_hot
4560 arma_inline
4561 arma_warn_unused
4562 uword
4563 SpMat<eT>::get_position(const uword i) const
4564 {
4565 uword lrow, lcol;
4566
4567 get_position(i, lrow, lcol);
4568
4569 // Assemble the row/col into the element's location in the matrix.
4570 return (lrow + n_rows * lcol);
4571 }
4572
4573
4574
4575 template<typename eT>
4576 arma_hot
4577 arma_inline
4578 void
4579 SpMat<eT>::get_position(const uword i, uword& row_of_i, uword& col_of_i) const
4580 {
4581 arma_debug_check((i >= n_nonzero), "SpMat::get_position(): index out of bounds");
4582
4583 col_of_i = 0;
4584 while (col_ptrs[col_of_i + 1] <= i)
4585 {
4586 col_of_i++;
4587 }
4588
4589 row_of_i = row_indices[i];
4590
4591 return;
4592 }
4593
4594
4595
4596 /**
4597 * Add an element at the given position, and return a reference to it. The
4598 * element will be set to 0 (unless otherwise specified). If the element
4599 * already exists, its value will be overwritten.
4600 *
4601 * @param in_row Row of new element.
4602 * @param in_col Column of new element.
4603 * @param in_val Value to set new element to (default 0.0).
4604 */
4605 template<typename eT>
4606 inline
4607 arma_hot
4608 arma_warn_unused
4609 eT&
4610 SpMat<eT>::add_element(const uword in_row, const uword in_col, const eT val)
4611 {
4612 arma_extra_debug_sigprint();
4613
4614 // We will assume the new element does not exist and begin the search for
4615 // where to insert it. If we find that it already exists, we will then
4616 // overwrite it.
4617 uword colptr = col_ptrs[in_col ];
4618 uword next_colptr = col_ptrs[in_col + 1];
4619
4620 uword pos = colptr; // The position in the matrix of this value.
4621
4622 if (colptr != next_colptr)
4623 {
4624 // There are other elements in this column, so we must find where this
4625 // element will fit as compared to those.
4626 while (pos < next_colptr && in_row > row_indices[pos])
4627 {
4628 pos++;
4629 }
4630
4631 // We aren't inserting into the last position, so it is still possible
4632 // that the element may exist.
4633 if (pos != next_colptr && row_indices[pos] == in_row)
4634 {
4635 // It already exists. Then, just overwrite it.
4636 access::rw(values[pos]) = val;
4637
4638 return access::rw(values[pos]);
4639 }
4640 }
4641
4642
4643 //
4644 // Element doesn't exist, so we have to insert it
4645 //
4646
4647 // We have to update the rest of the column pointers.
4648 for (uword i = in_col + 1; i < n_cols + 1; i++)
4649 {
4650 access::rw(col_ptrs[i])++; // We are only inserting one new element.
4651 }
4652
4653
4654 // Figure out the actual amount of memory currently allocated
4655 // NOTE: this relies on memory::acquire_chunked() being used for the 'values' and 'row_indices' arrays
4656 const uword n_alloc = memory::enlarge_to_mult_of_chunksize(n_nonzero + 1);
4657
4658 // If possible, avoid time-consuming memory allocation
4659 if(n_alloc > (n_nonzero + 1))
4660 {
4661 arrayops::copy_backwards(access::rwp(values) + pos + 1, values + pos, (n_nonzero - pos) + 1);
4662 arrayops::copy_backwards(access::rwp(row_indices) + pos + 1, row_indices + pos, (n_nonzero - pos) + 1);
4663
4664 // Insert the new element.
4665 access::rw(values[pos]) = val;
4666 access::rw(row_indices[pos]) = in_row;
4667
4668 access::rw(n_nonzero)++;
4669 }
4670 else
4671 {
4672 const uword old_n_nonzero = n_nonzero;
4673
4674 access::rw(n_nonzero)++; // Add to count of nonzero elements.
4675
4676 // Allocate larger memory.
4677 eT* new_values = memory::acquire_chunked<eT> (n_nonzero + 1);
4678 uword* new_row_indices = memory::acquire_chunked<uword>(n_nonzero + 1);
4679
4680 // Copy things over, before the new element.
4681 if (pos > 0)
4682 {
4683 arrayops::copy(new_values, values, pos);
4684 arrayops::copy(new_row_indices, row_indices, pos);
4685 }
4686
4687 // Insert the new element.
4688 new_values[pos] = val;
4689 new_row_indices[pos] = in_row;
4690
4691 // Copy the rest of things over (including the extra element at the end).
4692 arrayops::copy(new_values + pos + 1, values + pos, (old_n_nonzero - pos) + 1);
4693 arrayops::copy(new_row_indices + pos + 1, row_indices + pos, (old_n_nonzero - pos) + 1);
4694
4695 // Assign new pointers.
4696 memory::release(values);
4697 memory::release(row_indices);
4698
4699 access::rw(values) = new_values;
4700 access::rw(row_indices) = new_row_indices;
4701 }
4702
4703 return access::rw(values[pos]);
4704 }
4705
4706
4707
4708 /**
4709 * Delete an element at the given position.
4710 *
4711 * @param in_row Row of element to be deleted.
4712 * @param in_col Column of element to be deleted.
4713 */
4714 template<typename eT>
4715 inline
4716 arma_hot
4717 void
4718 SpMat<eT>::delete_element(const uword in_row, const uword in_col)
4719 {
4720 arma_extra_debug_sigprint();
4721
4722 // We assume the element exists (although... it may not) and look for its
4723 // exact position. If it doesn't exist... well, we don't need to do anything.
4724 uword colptr = col_ptrs[in_col];
4725 uword next_colptr = col_ptrs[in_col + 1];
4726
4727 if (colptr != next_colptr)
4728 {
4729 // There's at least one element in this column.
4730 // Let's see if we are one of them.
4731 for (uword pos = colptr; pos < next_colptr; pos++)
4732 {
4733 if (in_row == row_indices[pos])
4734 {
4735 const uword old_n_nonzero = n_nonzero;
4736
4737 --access::rw(n_nonzero); // Remove one from the count of nonzero elements.
4738
4739 // Found it. Now remove it.
4740
4741 // Figure out the actual amount of memory currently allocated and the actual amount that will be required
4742 // NOTE: this relies on memory::acquire_chunked() being used for the 'values' and 'row_indices' arrays
4743
4744 const uword n_alloc = memory::enlarge_to_mult_of_chunksize(old_n_nonzero + 1);
4745 const uword n_alloc_mod = memory::enlarge_to_mult_of_chunksize(n_nonzero + 1);
4746
4747 // If possible, avoid time-consuming memory allocation
4748 if(n_alloc_mod == n_alloc)
4749 {
4750 if (pos < n_nonzero) // remember, we decremented n_nonzero
4751 {
4752 arrayops::copy_forwards(access::rwp(values) + pos, values + pos + 1, (n_nonzero - pos) + 1);
4753 arrayops::copy_forwards(access::rwp(row_indices) + pos, row_indices + pos + 1, (n_nonzero - pos) + 1);
4754 }
4755 }
4756 else
4757 {
4758 // Make new arrays.
4759 eT* new_values = memory::acquire_chunked<eT> (n_nonzero + 1);
4760 uword* new_row_indices = memory::acquire_chunked<uword>(n_nonzero + 1);
4761
4762 if (pos > 0)
4763 {
4764 arrayops::copy(new_values, values, pos);
4765 arrayops::copy(new_row_indices, row_indices, pos);
4766 }
4767
4768 arrayops::copy(new_values + pos, values + pos + 1, (n_nonzero - pos) + 1);
4769 arrayops::copy(new_row_indices + pos, row_indices + pos + 1, (n_nonzero - pos) + 1);
4770
4771 memory::release(values);
4772 memory::release(row_indices);
4773
4774 access::rw(values) = new_values;
4775 access::rw(row_indices) = new_row_indices;
4776 }
4777
4778 // And lastly, update all the column pointers (decrement by one).
4779 for (uword i = in_col + 1; i < n_cols + 1; i++)
4780 {
4781 --access::rw(col_ptrs[i]); // We only removed one element.
4782 }
4783
4784 return; // There is nothing left to do.
4785 }
4786 }
4787 }
4788
4789 return; // The element does not exist, so there's nothing for us to do.
4790 }
4791
4792
4793
4794 #ifdef ARMA_EXTRA_SPMAT_MEAT
4795 #include ARMA_INCFILE_WRAP(ARMA_EXTRA_SPMAT_MEAT)
4796 #endif
4797
4798
4799
4800 //! @}