Chris@49
|
1 // Copyright (C) 2008-2012 NICTA (www.nicta.com.au)
|
Chris@49
|
2 // Copyright (C) 2008-2012 Conrad Sanderson
|
Chris@49
|
3 // Copyright (C) 2012 Ryan Curtin
|
Chris@49
|
4 //
|
Chris@49
|
5 // This Source Code Form is subject to the terms of the Mozilla Public
|
Chris@49
|
6 // License, v. 2.0. If a copy of the MPL was not distributed with this
|
Chris@49
|
7 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
Chris@49
|
8
|
Chris@49
|
9
|
Chris@49
|
10 //! \addtogroup operator_schur
|
Chris@49
|
11 //! @{
|
Chris@49
|
12
|
Chris@49
|
13
|
Chris@49
|
14 // operator %, which we define to perform a Schur product (element-wise multiplication)
|
Chris@49
|
15
|
Chris@49
|
16
|
Chris@49
|
17 //! element-wise multiplication of user-accessible Armadillo objects with same element type
|
Chris@49
|
18 template<typename T1, typename T2>
|
Chris@49
|
19 arma_inline
|
Chris@49
|
20 typename
|
Chris@49
|
21 enable_if2
|
Chris@49
|
22 <
|
Chris@49
|
23 is_arma_type<T1>::value && is_arma_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value,
|
Chris@49
|
24 const eGlue<T1, T2, eglue_schur>
|
Chris@49
|
25 >::result
|
Chris@49
|
26 operator%
|
Chris@49
|
27 (
|
Chris@49
|
28 const T1& X,
|
Chris@49
|
29 const T2& Y
|
Chris@49
|
30 )
|
Chris@49
|
31 {
|
Chris@49
|
32 arma_extra_debug_sigprint();
|
Chris@49
|
33
|
Chris@49
|
34 return eGlue<T1, T2, eglue_schur>(X, Y);
|
Chris@49
|
35 }
|
Chris@49
|
36
|
Chris@49
|
37
|
Chris@49
|
38
|
Chris@49
|
39 //! element-wise multiplication of user-accessible Armadillo objects with different element types
|
Chris@49
|
40 template<typename T1, typename T2>
|
Chris@49
|
41 inline
|
Chris@49
|
42 typename
|
Chris@49
|
43 enable_if2
|
Chris@49
|
44 <
|
Chris@49
|
45 (is_arma_type<T1>::value && is_arma_type<T2>::value && (is_same_type<typename T1::elem_type, typename T2::elem_type>::value == false)),
|
Chris@49
|
46 const mtGlue<typename promote_type<typename T1::elem_type, typename T2::elem_type>::result, T1, T2, glue_mixed_schur>
|
Chris@49
|
47 >::result
|
Chris@49
|
48 operator%
|
Chris@49
|
49 (
|
Chris@49
|
50 const T1& X,
|
Chris@49
|
51 const T2& Y
|
Chris@49
|
52 )
|
Chris@49
|
53 {
|
Chris@49
|
54 arma_extra_debug_sigprint();
|
Chris@49
|
55
|
Chris@49
|
56 typedef typename T1::elem_type eT1;
|
Chris@49
|
57 typedef typename T2::elem_type eT2;
|
Chris@49
|
58
|
Chris@49
|
59 typedef typename promote_type<eT1,eT2>::result out_eT;
|
Chris@49
|
60
|
Chris@49
|
61 promote_type<eT1,eT2>::check();
|
Chris@49
|
62
|
Chris@49
|
63 return mtGlue<out_eT, T1, T2, glue_mixed_schur>( X, Y );
|
Chris@49
|
64 }
|
Chris@49
|
65
|
Chris@49
|
66
|
Chris@49
|
67
|
Chris@49
|
68 //! element-wise multiplication of two sparse matrices
|
Chris@49
|
69 template<typename T1, typename T2>
|
Chris@49
|
70 inline
|
Chris@49
|
71 typename
|
Chris@49
|
72 enable_if2
|
Chris@49
|
73 <
|
Chris@49
|
74 (is_arma_sparse_type<T1>::value && is_arma_sparse_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value),
|
Chris@49
|
75 SpMat<typename T1::elem_type>
|
Chris@49
|
76 >::result
|
Chris@49
|
77 operator%
|
Chris@49
|
78 (
|
Chris@49
|
79 const SpBase<typename T1::elem_type, T1>& x,
|
Chris@49
|
80 const SpBase<typename T2::elem_type, T2>& y
|
Chris@49
|
81 )
|
Chris@49
|
82 {
|
Chris@49
|
83 arma_extra_debug_sigprint();
|
Chris@49
|
84
|
Chris@49
|
85 typedef typename T1::elem_type eT;
|
Chris@49
|
86
|
Chris@49
|
87 const SpProxy<T1> pa(x.get_ref());
|
Chris@49
|
88 const SpProxy<T2> pb(y.get_ref());
|
Chris@49
|
89
|
Chris@49
|
90 arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "element-wise multiplication");
|
Chris@49
|
91
|
Chris@49
|
92 SpMat<typename T1::elem_type> result(pa.get_n_rows(), pa.get_n_cols());
|
Chris@49
|
93
|
Chris@49
|
94 if( (pa.get_n_nonzero() != 0) && (pb.get_n_nonzero() != 0) )
|
Chris@49
|
95 {
|
Chris@49
|
96 // Resize memory to correct size.
|
Chris@49
|
97 result.mem_resize(n_unique(x, y, op_n_unique_mul()));
|
Chris@49
|
98
|
Chris@49
|
99 // Now iterate across both matrices.
|
Chris@49
|
100 typename SpProxy<T1>::const_iterator_type x_it = pa.begin();
|
Chris@49
|
101 typename SpProxy<T2>::const_iterator_type y_it = pb.begin();
|
Chris@49
|
102
|
Chris@49
|
103 typename SpProxy<T1>::const_iterator_type x_end = pa.end();
|
Chris@49
|
104 typename SpProxy<T2>::const_iterator_type y_end = pb.end();
|
Chris@49
|
105
|
Chris@49
|
106 uword cur_val = 0;
|
Chris@49
|
107 while((x_it != x_end) || (y_it != y_end))
|
Chris@49
|
108 {
|
Chris@49
|
109 if(x_it == y_it)
|
Chris@49
|
110 {
|
Chris@49
|
111 const eT val = (*x_it) * (*y_it);
|
Chris@49
|
112
|
Chris@49
|
113 if (val != eT(0))
|
Chris@49
|
114 {
|
Chris@49
|
115 access::rw(result.values[cur_val]) = val;
|
Chris@49
|
116 access::rw(result.row_indices[cur_val]) = x_it.row();
|
Chris@49
|
117 ++access::rw(result.col_ptrs[x_it.col() + 1]);
|
Chris@49
|
118 ++cur_val;
|
Chris@49
|
119 }
|
Chris@49
|
120
|
Chris@49
|
121 ++x_it;
|
Chris@49
|
122 ++y_it;
|
Chris@49
|
123 }
|
Chris@49
|
124 else
|
Chris@49
|
125 {
|
Chris@49
|
126 const uword x_it_row = x_it.row();
|
Chris@49
|
127 const uword x_it_col = x_it.col();
|
Chris@49
|
128
|
Chris@49
|
129 const uword y_it_row = y_it.row();
|
Chris@49
|
130 const uword y_it_col = y_it.col();
|
Chris@49
|
131
|
Chris@49
|
132 if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end
|
Chris@49
|
133 {
|
Chris@49
|
134 ++x_it;
|
Chris@49
|
135 }
|
Chris@49
|
136 else
|
Chris@49
|
137 {
|
Chris@49
|
138 ++y_it;
|
Chris@49
|
139 }
|
Chris@49
|
140 }
|
Chris@49
|
141 }
|
Chris@49
|
142
|
Chris@49
|
143 // Fix column pointers to be cumulative.
|
Chris@49
|
144 for(uword c = 1; c <= result.n_cols; ++c)
|
Chris@49
|
145 {
|
Chris@49
|
146 access::rw(result.col_ptrs[c]) += result.col_ptrs[c - 1];
|
Chris@49
|
147 }
|
Chris@49
|
148 }
|
Chris@49
|
149
|
Chris@49
|
150 return result;
|
Chris@49
|
151 }
|
Chris@49
|
152
|
Chris@49
|
153
|
Chris@49
|
154
|
Chris@49
|
155 //! element-wise multiplication of one dense and one sparse object
|
Chris@49
|
156 template<typename T1, typename T2>
|
Chris@49
|
157 inline
|
Chris@49
|
158 typename
|
Chris@49
|
159 enable_if2
|
Chris@49
|
160 <
|
Chris@49
|
161 (is_arma_type<T1>::value && is_arma_sparse_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value),
|
Chris@49
|
162 SpMat<typename T1::elem_type>
|
Chris@49
|
163 >::result
|
Chris@49
|
164 operator%
|
Chris@49
|
165 (
|
Chris@49
|
166 const T1& x,
|
Chris@49
|
167 const T2& y
|
Chris@49
|
168 )
|
Chris@49
|
169 {
|
Chris@49
|
170 arma_extra_debug_sigprint();
|
Chris@49
|
171
|
Chris@49
|
172 typedef typename T1::elem_type eT;
|
Chris@49
|
173
|
Chris@49
|
174 const Proxy<T1> pa(x);
|
Chris@49
|
175 const SpProxy<T2> pb(y);
|
Chris@49
|
176
|
Chris@49
|
177 arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "element-wise multiplication");
|
Chris@49
|
178
|
Chris@49
|
179 SpMat<eT> result(pa.get_n_rows(), pa.get_n_cols());
|
Chris@49
|
180
|
Chris@49
|
181 // count new size
|
Chris@49
|
182 uword new_n_nonzero = 0;
|
Chris@49
|
183
|
Chris@49
|
184 typename SpProxy<T2>::const_iterator_type it = pb.begin();
|
Chris@49
|
185 typename SpProxy<T2>::const_iterator_type it_end = pb.end();
|
Chris@49
|
186
|
Chris@49
|
187 while(it != it_end)
|
Chris@49
|
188 {
|
Chris@49
|
189 if( ((*it) * pa.at(it.row(), it.col())) != eT(0) )
|
Chris@49
|
190 {
|
Chris@49
|
191 ++new_n_nonzero;
|
Chris@49
|
192 }
|
Chris@49
|
193
|
Chris@49
|
194 ++it;
|
Chris@49
|
195 }
|
Chris@49
|
196
|
Chris@49
|
197 // Resize memory accordingly.
|
Chris@49
|
198 result.mem_resize(new_n_nonzero);
|
Chris@49
|
199
|
Chris@49
|
200 uword cur_val = 0;
|
Chris@49
|
201
|
Chris@49
|
202 typename SpProxy<T2>::const_iterator_type it2 = pb.begin();
|
Chris@49
|
203
|
Chris@49
|
204 while(it2 != it_end)
|
Chris@49
|
205 {
|
Chris@49
|
206 const uword it2_row = it2.row();
|
Chris@49
|
207 const uword it2_col = it2.col();
|
Chris@49
|
208
|
Chris@49
|
209 const eT val = (*it2) * pa.at(it2_row, it2_col);
|
Chris@49
|
210
|
Chris@49
|
211 if(val != eT(0))
|
Chris@49
|
212 {
|
Chris@49
|
213 access::rw(result.values[cur_val]) = val;
|
Chris@49
|
214 access::rw(result.row_indices[cur_val]) = it2_row;
|
Chris@49
|
215 ++access::rw(result.col_ptrs[it2_col + 1]);
|
Chris@49
|
216 ++cur_val;
|
Chris@49
|
217 }
|
Chris@49
|
218
|
Chris@49
|
219 ++it2;
|
Chris@49
|
220 }
|
Chris@49
|
221
|
Chris@49
|
222 // Fix column pointers.
|
Chris@49
|
223 for(uword c = 1; c <= result.n_cols; ++c)
|
Chris@49
|
224 {
|
Chris@49
|
225 access::rw(result.col_ptrs[c]) += result.col_ptrs[c - 1];
|
Chris@49
|
226 }
|
Chris@49
|
227
|
Chris@49
|
228 return result;
|
Chris@49
|
229 }
|
Chris@49
|
230
|
Chris@49
|
231
|
Chris@49
|
232
|
Chris@49
|
233 //! element-wise multiplication of one sparse and one dense object
|
Chris@49
|
234 template<typename T1, typename T2>
|
Chris@49
|
235 inline
|
Chris@49
|
236 typename
|
Chris@49
|
237 enable_if2
|
Chris@49
|
238 <
|
Chris@49
|
239 (is_arma_sparse_type<T1>::value && is_arma_type<T2>::value && is_same_type<typename T1::elem_type, typename T2::elem_type>::value),
|
Chris@49
|
240 SpMat<typename T1::elem_type>
|
Chris@49
|
241 >::result
|
Chris@49
|
242 operator%
|
Chris@49
|
243 (
|
Chris@49
|
244 const T1& x,
|
Chris@49
|
245 const T2& y
|
Chris@49
|
246 )
|
Chris@49
|
247 {
|
Chris@49
|
248 arma_extra_debug_sigprint();
|
Chris@49
|
249
|
Chris@49
|
250 // This operation is commutative.
|
Chris@49
|
251 return (y % x);
|
Chris@49
|
252 }
|
Chris@49
|
253
|
Chris@49
|
254
|
Chris@49
|
255
|
Chris@49
|
256 //! @}
|