Chris@49
|
1 // Copyright (C) 2012 Ryan Curtin
|
Chris@49
|
2 // Copyright (C) 2012 Conrad Sanderson
|
Chris@49
|
3 //
|
Chris@49
|
4 // This Source Code Form is subject to the terms of the Mozilla Public
|
Chris@49
|
5 // License, v. 2.0. If a copy of the MPL was not distributed with this
|
Chris@49
|
6 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
Chris@49
|
7
|
Chris@49
|
8
|
Chris@49
|
9 //! \addtogroup spglue_plus
|
Chris@49
|
10 //! @{
|
Chris@49
|
11
|
Chris@49
|
12
|
Chris@49
|
13
|
Chris@49
|
14 template<typename T1, typename T2>
|
Chris@49
|
15 arma_hot
|
Chris@49
|
16 inline
|
Chris@49
|
17 void
|
Chris@49
|
18 spglue_plus::apply(SpMat<typename T1::elem_type>& out, const SpGlue<T1,T2,spglue_plus>& X)
|
Chris@49
|
19 {
|
Chris@49
|
20 arma_extra_debug_sigprint();
|
Chris@49
|
21
|
Chris@49
|
22 typedef typename T1::elem_type eT;
|
Chris@49
|
23
|
Chris@49
|
24 const SpProxy<T1> pa(X.A);
|
Chris@49
|
25 const SpProxy<T2> pb(X.B);
|
Chris@49
|
26
|
Chris@49
|
27 const bool is_alias = pa.is_alias(out) || pb.is_alias(out);
|
Chris@49
|
28
|
Chris@49
|
29 if(is_alias == false)
|
Chris@49
|
30 {
|
Chris@49
|
31 spglue_plus::apply_noalias(out, pa, pb);
|
Chris@49
|
32 }
|
Chris@49
|
33 else
|
Chris@49
|
34 {
|
Chris@49
|
35 SpMat<eT> tmp;
|
Chris@49
|
36 spglue_plus::apply_noalias(tmp, pa, pb);
|
Chris@49
|
37
|
Chris@49
|
38 out.steal_mem(tmp);
|
Chris@49
|
39 }
|
Chris@49
|
40 }
|
Chris@49
|
41
|
Chris@49
|
42
|
Chris@49
|
43
|
Chris@49
|
44 template<typename eT, typename T1, typename T2>
|
Chris@49
|
45 arma_hot
|
Chris@49
|
46 inline
|
Chris@49
|
47 void
|
Chris@49
|
48 spglue_plus::apply_noalias(SpMat<eT>& out, const SpProxy<T1>& pa, const SpProxy<T2>& pb)
|
Chris@49
|
49 {
|
Chris@49
|
50 arma_extra_debug_sigprint();
|
Chris@49
|
51
|
Chris@49
|
52 arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "addition");
|
Chris@49
|
53
|
Chris@49
|
54 if( (pa.get_n_nonzero() != 0) && (pb.get_n_nonzero() != 0) )
|
Chris@49
|
55 {
|
Chris@49
|
56 out.set_size(pa.get_n_rows(), pa.get_n_cols());
|
Chris@49
|
57
|
Chris@49
|
58 // Resize memory to correct size.
|
Chris@49
|
59 out.mem_resize(n_unique(pa, pb, op_n_unique_add()));
|
Chris@49
|
60
|
Chris@49
|
61 // Now iterate across both matrices.
|
Chris@49
|
62 typename SpProxy<T1>::const_iterator_type x_it = pa.begin();
|
Chris@49
|
63 typename SpProxy<T2>::const_iterator_type y_it = pb.begin();
|
Chris@49
|
64
|
Chris@49
|
65 typename SpProxy<T1>::const_iterator_type x_end = pa.end();
|
Chris@49
|
66 typename SpProxy<T2>::const_iterator_type y_end = pb.end();
|
Chris@49
|
67
|
Chris@49
|
68 uword cur_val = 0;
|
Chris@49
|
69 while( (x_it != x_end) || (y_it != y_end) )
|
Chris@49
|
70 {
|
Chris@49
|
71 if(x_it == y_it)
|
Chris@49
|
72 {
|
Chris@49
|
73 const eT val = (*x_it) + (*y_it);
|
Chris@49
|
74
|
Chris@49
|
75 if (val != eT(0))
|
Chris@49
|
76 {
|
Chris@49
|
77 access::rw(out.values[cur_val]) = val;
|
Chris@49
|
78 access::rw(out.row_indices[cur_val]) = x_it.row();
|
Chris@49
|
79 ++access::rw(out.col_ptrs[x_it.col() + 1]);
|
Chris@49
|
80 ++cur_val;
|
Chris@49
|
81 }
|
Chris@49
|
82
|
Chris@49
|
83 ++x_it;
|
Chris@49
|
84 ++y_it;
|
Chris@49
|
85 }
|
Chris@49
|
86 else
|
Chris@49
|
87 {
|
Chris@49
|
88 const uword x_it_row = x_it.row();
|
Chris@49
|
89 const uword x_it_col = x_it.col();
|
Chris@49
|
90
|
Chris@49
|
91 const uword y_it_row = y_it.row();
|
Chris@49
|
92 const uword y_it_col = y_it.col();
|
Chris@49
|
93
|
Chris@49
|
94 if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end
|
Chris@49
|
95 {
|
Chris@49
|
96 access::rw(out.values[cur_val]) = (*x_it);
|
Chris@49
|
97 access::rw(out.row_indices[cur_val]) = x_it_row;
|
Chris@49
|
98 ++access::rw(out.col_ptrs[x_it_col + 1]);
|
Chris@49
|
99 ++cur_val;
|
Chris@49
|
100 ++x_it;
|
Chris@49
|
101 }
|
Chris@49
|
102 else
|
Chris@49
|
103 {
|
Chris@49
|
104 access::rw(out.values[cur_val]) = (*y_it);
|
Chris@49
|
105 access::rw(out.row_indices[cur_val]) = y_it_row;
|
Chris@49
|
106 ++access::rw(out.col_ptrs[y_it_col + 1]);
|
Chris@49
|
107 ++cur_val;
|
Chris@49
|
108 ++y_it;
|
Chris@49
|
109 }
|
Chris@49
|
110 }
|
Chris@49
|
111 }
|
Chris@49
|
112
|
Chris@49
|
113 const uword out_n_cols = out.n_cols;
|
Chris@49
|
114
|
Chris@49
|
115 uword* col_ptrs = access::rwp(out.col_ptrs);
|
Chris@49
|
116
|
Chris@49
|
117 // Fix column pointers to be cumulative.
|
Chris@49
|
118 for(uword c = 1; c <= out_n_cols; ++c)
|
Chris@49
|
119 {
|
Chris@49
|
120 col_ptrs[c] += col_ptrs[c - 1];
|
Chris@49
|
121 }
|
Chris@49
|
122 }
|
Chris@49
|
123 else
|
Chris@49
|
124 {
|
Chris@49
|
125 if(pa.get_n_nonzero() == 0)
|
Chris@49
|
126 {
|
Chris@49
|
127 out = pb.Q;
|
Chris@49
|
128 return;
|
Chris@49
|
129 }
|
Chris@49
|
130
|
Chris@49
|
131 if(pb.get_n_nonzero() == 0)
|
Chris@49
|
132 {
|
Chris@49
|
133 out = pa.Q;
|
Chris@49
|
134 return;
|
Chris@49
|
135 }
|
Chris@49
|
136 }
|
Chris@49
|
137 }
|
Chris@49
|
138
|
Chris@49
|
139
|
Chris@49
|
140
|
Chris@49
|
141 //
|
Chris@49
|
142 //
|
Chris@49
|
143 // spglue_plus2: scalar*(A + B)
|
Chris@49
|
144
|
Chris@49
|
145
|
Chris@49
|
146
|
Chris@49
|
147 template<typename T1, typename T2>
|
Chris@49
|
148 arma_hot
|
Chris@49
|
149 inline
|
Chris@49
|
150 void
|
Chris@49
|
151 spglue_plus2::apply(SpMat<typename T1::elem_type>& out, const SpGlue<T1,T2,spglue_plus2>& X)
|
Chris@49
|
152 {
|
Chris@49
|
153 arma_extra_debug_sigprint();
|
Chris@49
|
154
|
Chris@49
|
155 typedef typename T1::elem_type eT;
|
Chris@49
|
156
|
Chris@49
|
157 const SpProxy<T1> pa(X.A);
|
Chris@49
|
158 const SpProxy<T2> pb(X.B);
|
Chris@49
|
159
|
Chris@49
|
160 const bool is_alias = pa.is_alias(out) || pb.is_alias(out);
|
Chris@49
|
161
|
Chris@49
|
162 if(is_alias == false)
|
Chris@49
|
163 {
|
Chris@49
|
164 spglue_plus::apply_noalias(out, pa, pb);
|
Chris@49
|
165 }
|
Chris@49
|
166 else
|
Chris@49
|
167 {
|
Chris@49
|
168 SpMat<eT> tmp;
|
Chris@49
|
169 spglue_plus::apply_noalias(tmp, pa, pb);
|
Chris@49
|
170
|
Chris@49
|
171 out.steal_mem(tmp);
|
Chris@49
|
172 }
|
Chris@49
|
173
|
Chris@49
|
174 out *= X.aux;
|
Chris@49
|
175 }
|
Chris@49
|
176
|
Chris@49
|
177
|
Chris@49
|
178
|
Chris@49
|
179 //! @}
|