Chris@49
|
1 // Copyright (C) 2008-2013 NICTA (www.nicta.com.au)
|
Chris@49
|
2 // Copyright (C) 2008-2013 Conrad Sanderson
|
Chris@49
|
3 //
|
Chris@49
|
4 // This Source Code Form is subject to the terms of the Mozilla Public
|
Chris@49
|
5 // License, v. 2.0. If a copy of the MPL was not distributed with this
|
Chris@49
|
6 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
Chris@49
|
7
|
Chris@49
|
8
|
Chris@49
|
9 //! \addtogroup glue_times
|
Chris@49
|
10 //! @{
|
Chris@49
|
11
|
Chris@49
|
12
|
Chris@49
|
13
|
Chris@49
|
14 template<bool is_eT_blas_type>
|
Chris@49
|
15 template<typename T1, typename T2>
|
Chris@49
|
16 arma_hot
|
Chris@49
|
17 inline
|
Chris@49
|
18 void
|
Chris@49
|
19 glue_times_redirect2_helper<is_eT_blas_type>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
|
Chris@49
|
20 {
|
Chris@49
|
21 arma_extra_debug_sigprint();
|
Chris@49
|
22
|
Chris@49
|
23 typedef typename T1::elem_type eT;
|
Chris@49
|
24
|
Chris@49
|
25 const partial_unwrap_check<T1> tmp1(X.A, out);
|
Chris@49
|
26 const partial_unwrap_check<T2> tmp2(X.B, out);
|
Chris@49
|
27
|
Chris@49
|
28 const typename partial_unwrap_check<T1>::stored_type& A = tmp1.M;
|
Chris@49
|
29 const typename partial_unwrap_check<T2>::stored_type& B = tmp2.M;
|
Chris@49
|
30
|
Chris@49
|
31 const bool use_alpha = partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times;
|
Chris@49
|
32 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0);
|
Chris@49
|
33
|
Chris@49
|
34 glue_times::apply
|
Chris@49
|
35 <
|
Chris@49
|
36 eT,
|
Chris@49
|
37 partial_unwrap_check<T1>::do_trans,
|
Chris@49
|
38 partial_unwrap_check<T2>::do_trans,
|
Chris@49
|
39 (partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times)
|
Chris@49
|
40 >
|
Chris@49
|
41 (out, A, B, alpha);
|
Chris@49
|
42 }
|
Chris@49
|
43
|
Chris@49
|
44
|
Chris@49
|
45
|
Chris@49
|
46 template<typename T1, typename T2>
|
Chris@49
|
47 arma_hot
|
Chris@49
|
48 inline
|
Chris@49
|
49 void
|
Chris@49
|
50 glue_times_redirect2_helper<true>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
|
Chris@49
|
51 {
|
Chris@49
|
52 arma_extra_debug_sigprint();
|
Chris@49
|
53
|
Chris@49
|
54 typedef typename T1::elem_type eT;
|
Chris@49
|
55
|
Chris@49
|
56 if(strip_inv<T1>::do_inv == false)
|
Chris@49
|
57 {
|
Chris@49
|
58 const partial_unwrap_check<T1> tmp1(X.A, out);
|
Chris@49
|
59 const partial_unwrap_check<T2> tmp2(X.B, out);
|
Chris@49
|
60
|
Chris@49
|
61 const typename partial_unwrap_check<T1>::stored_type& A = tmp1.M;
|
Chris@49
|
62 const typename partial_unwrap_check<T2>::stored_type& B = tmp2.M;
|
Chris@49
|
63
|
Chris@49
|
64 const bool use_alpha = partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times;
|
Chris@49
|
65 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0);
|
Chris@49
|
66
|
Chris@49
|
67 glue_times::apply
|
Chris@49
|
68 <
|
Chris@49
|
69 eT,
|
Chris@49
|
70 partial_unwrap_check<T1>::do_trans,
|
Chris@49
|
71 partial_unwrap_check<T2>::do_trans,
|
Chris@49
|
72 (partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times)
|
Chris@49
|
73 >
|
Chris@49
|
74 (out, A, B, alpha);
|
Chris@49
|
75 }
|
Chris@49
|
76 else
|
Chris@49
|
77 {
|
Chris@49
|
78 arma_extra_debug_print("glue_times_redirect<2>::apply(): detected inv(A)*B");
|
Chris@49
|
79
|
Chris@49
|
80 const strip_inv<T1> A_strip(X.A);
|
Chris@49
|
81
|
Chris@49
|
82 Mat<eT> A = A_strip.M;
|
Chris@49
|
83
|
Chris@49
|
84 arma_debug_check( (A.is_square() == false), "inv(): given matrix is not square" );
|
Chris@49
|
85
|
Chris@49
|
86 const unwrap_check<T2> B_tmp(X.B, out);
|
Chris@49
|
87 const Mat<eT>& B = B_tmp.M;
|
Chris@49
|
88
|
Chris@49
|
89 glue_solve::solve_direct( out, A, B, A_strip.slow );
|
Chris@49
|
90 }
|
Chris@49
|
91 }
|
Chris@49
|
92
|
Chris@49
|
93
|
Chris@49
|
94
|
Chris@49
|
95 template<uword N>
|
Chris@49
|
96 template<typename T1, typename T2>
|
Chris@49
|
97 arma_hot
|
Chris@49
|
98 inline
|
Chris@49
|
99 void
|
Chris@49
|
100 glue_times_redirect<N>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
|
Chris@49
|
101 {
|
Chris@49
|
102 arma_extra_debug_sigprint();
|
Chris@49
|
103
|
Chris@49
|
104 typedef typename T1::elem_type eT;
|
Chris@49
|
105
|
Chris@49
|
106 const partial_unwrap_check<T1> tmp1(X.A, out);
|
Chris@49
|
107 const partial_unwrap_check<T2> tmp2(X.B, out);
|
Chris@49
|
108
|
Chris@49
|
109 const typename partial_unwrap_check<T1>::stored_type& A = tmp1.M;
|
Chris@49
|
110 const typename partial_unwrap_check<T2>::stored_type& B = tmp2.M;
|
Chris@49
|
111
|
Chris@49
|
112 const bool use_alpha = partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times;
|
Chris@49
|
113 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0);
|
Chris@49
|
114
|
Chris@49
|
115 glue_times::apply
|
Chris@49
|
116 <
|
Chris@49
|
117 eT,
|
Chris@49
|
118 partial_unwrap_check<T1>::do_trans,
|
Chris@49
|
119 partial_unwrap_check<T2>::do_trans,
|
Chris@49
|
120 (partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times)
|
Chris@49
|
121 >
|
Chris@49
|
122 (out, A, B, alpha);
|
Chris@49
|
123 }
|
Chris@49
|
124
|
Chris@49
|
125
|
Chris@49
|
126
|
Chris@49
|
127 template<typename T1, typename T2>
|
Chris@49
|
128 arma_hot
|
Chris@49
|
129 inline
|
Chris@49
|
130 void
|
Chris@49
|
131 glue_times_redirect<2>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
|
Chris@49
|
132 {
|
Chris@49
|
133 arma_extra_debug_sigprint();
|
Chris@49
|
134
|
Chris@49
|
135 typedef typename T1::elem_type eT;
|
Chris@49
|
136
|
Chris@49
|
137 glue_times_redirect2_helper< is_supported_blas_type<eT>::value >::apply(out, X);
|
Chris@49
|
138 }
|
Chris@49
|
139
|
Chris@49
|
140
|
Chris@49
|
141
|
Chris@49
|
142 template<typename T1, typename T2, typename T3>
|
Chris@49
|
143 arma_hot
|
Chris@49
|
144 inline
|
Chris@49
|
145 void
|
Chris@49
|
146 glue_times_redirect<3>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue<T1,T2,glue_times>, T3, glue_times>& X)
|
Chris@49
|
147 {
|
Chris@49
|
148 arma_extra_debug_sigprint();
|
Chris@49
|
149
|
Chris@49
|
150 typedef typename T1::elem_type eT;
|
Chris@49
|
151
|
Chris@49
|
152 // TODO: investigate detecting inv(A)*B*C and replacing with solve(A,B)*C
|
Chris@49
|
153 // TODO: investigate detecting A*inv(B)*C and replacing with A*solve(B,C)
|
Chris@49
|
154
|
Chris@49
|
155 // there is exactly 3 objects
|
Chris@49
|
156 // hence we can safely expand X as X.A.A, X.A.B and X.B
|
Chris@49
|
157
|
Chris@49
|
158 const partial_unwrap_check<T1> tmp1(X.A.A, out);
|
Chris@49
|
159 const partial_unwrap_check<T2> tmp2(X.A.B, out);
|
Chris@49
|
160 const partial_unwrap_check<T3> tmp3(X.B, out);
|
Chris@49
|
161
|
Chris@49
|
162 const typename partial_unwrap_check<T1>::stored_type& A = tmp1.M;
|
Chris@49
|
163 const typename partial_unwrap_check<T2>::stored_type& B = tmp2.M;
|
Chris@49
|
164 const typename partial_unwrap_check<T3>::stored_type& C = tmp3.M;
|
Chris@49
|
165
|
Chris@49
|
166 const bool use_alpha = partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times || partial_unwrap_check<T3>::do_times;
|
Chris@49
|
167 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val()) : eT(0);
|
Chris@49
|
168
|
Chris@49
|
169 glue_times::apply
|
Chris@49
|
170 <
|
Chris@49
|
171 eT,
|
Chris@49
|
172 partial_unwrap_check<T1>::do_trans,
|
Chris@49
|
173 partial_unwrap_check<T2>::do_trans,
|
Chris@49
|
174 partial_unwrap_check<T3>::do_trans,
|
Chris@49
|
175 (partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times || partial_unwrap_check<T3>::do_times)
|
Chris@49
|
176 >
|
Chris@49
|
177 (out, A, B, C, alpha);
|
Chris@49
|
178 }
|
Chris@49
|
179
|
Chris@49
|
180
|
Chris@49
|
181
|
Chris@49
|
182 template<typename T1, typename T2, typename T3, typename T4>
|
Chris@49
|
183 arma_hot
|
Chris@49
|
184 inline
|
Chris@49
|
185 void
|
Chris@49
|
186 glue_times_redirect<4>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue< Glue<T1,T2,glue_times>, T3, glue_times>, T4, glue_times>& X)
|
Chris@49
|
187 {
|
Chris@49
|
188 arma_extra_debug_sigprint();
|
Chris@49
|
189
|
Chris@49
|
190 typedef typename T1::elem_type eT;
|
Chris@49
|
191
|
Chris@49
|
192 // there is exactly 4 objects
|
Chris@49
|
193 // hence we can safely expand X as X.A.A.A, X.A.A.B, X.A.B and X.B
|
Chris@49
|
194
|
Chris@49
|
195 const partial_unwrap_check<T1> tmp1(X.A.A.A, out);
|
Chris@49
|
196 const partial_unwrap_check<T2> tmp2(X.A.A.B, out);
|
Chris@49
|
197 const partial_unwrap_check<T3> tmp3(X.A.B, out);
|
Chris@49
|
198 const partial_unwrap_check<T4> tmp4(X.B, out);
|
Chris@49
|
199
|
Chris@49
|
200 const typename partial_unwrap_check<T1>::stored_type& A = tmp1.M;
|
Chris@49
|
201 const typename partial_unwrap_check<T2>::stored_type& B = tmp2.M;
|
Chris@49
|
202 const typename partial_unwrap_check<T3>::stored_type& C = tmp3.M;
|
Chris@49
|
203 const typename partial_unwrap_check<T4>::stored_type& D = tmp4.M;
|
Chris@49
|
204
|
Chris@49
|
205 const bool use_alpha = partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times || partial_unwrap_check<T3>::do_times || partial_unwrap_check<T4>::do_times;
|
Chris@49
|
206 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val() * tmp4.get_val()) : eT(0);
|
Chris@49
|
207
|
Chris@49
|
208 glue_times::apply
|
Chris@49
|
209 <
|
Chris@49
|
210 eT,
|
Chris@49
|
211 partial_unwrap_check<T1>::do_trans,
|
Chris@49
|
212 partial_unwrap_check<T2>::do_trans,
|
Chris@49
|
213 partial_unwrap_check<T3>::do_trans,
|
Chris@49
|
214 partial_unwrap_check<T4>::do_trans,
|
Chris@49
|
215 (partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times || partial_unwrap_check<T3>::do_times || partial_unwrap_check<T4>::do_times)
|
Chris@49
|
216 >
|
Chris@49
|
217 (out, A, B, C, D, alpha);
|
Chris@49
|
218 }
|
Chris@49
|
219
|
Chris@49
|
220
|
Chris@49
|
221
|
Chris@49
|
222 template<typename T1, typename T2>
|
Chris@49
|
223 arma_hot
|
Chris@49
|
224 inline
|
Chris@49
|
225 void
|
Chris@49
|
226 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
|
Chris@49
|
227 {
|
Chris@49
|
228 arma_extra_debug_sigprint();
|
Chris@49
|
229
|
Chris@49
|
230 const sword N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num;
|
Chris@49
|
231
|
Chris@49
|
232 arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat);
|
Chris@49
|
233
|
Chris@49
|
234 glue_times_redirect<N_mat>::apply(out, X);
|
Chris@49
|
235 }
|
Chris@49
|
236
|
Chris@49
|
237
|
Chris@49
|
238
|
Chris@49
|
239 template<typename T1>
|
Chris@49
|
240 arma_hot
|
Chris@49
|
241 inline
|
Chris@49
|
242 void
|
Chris@49
|
243 glue_times::apply_inplace(Mat<typename T1::elem_type>& out, const T1& X)
|
Chris@49
|
244 {
|
Chris@49
|
245 arma_extra_debug_sigprint();
|
Chris@49
|
246
|
Chris@49
|
247 typedef typename T1::elem_type eT;
|
Chris@49
|
248
|
Chris@49
|
249 const unwrap_check<T1> B_tmp(X, out);
|
Chris@49
|
250 const Mat<eT>& B = B_tmp.M;
|
Chris@49
|
251
|
Chris@49
|
252 arma_debug_assert_mul_size(out, B, "matrix multiplication");
|
Chris@49
|
253
|
Chris@49
|
254 const uword out_n_rows = out.n_rows;
|
Chris@49
|
255 const uword out_n_cols = out.n_cols;
|
Chris@49
|
256
|
Chris@49
|
257 if(out_n_cols == B.n_cols)
|
Chris@49
|
258 {
|
Chris@49
|
259 // size of resulting matrix is the same as 'out'
|
Chris@49
|
260
|
Chris@49
|
261 podarray<eT> tmp(out_n_cols);
|
Chris@49
|
262
|
Chris@49
|
263 eT* tmp_rowdata = tmp.memptr();
|
Chris@49
|
264
|
Chris@49
|
265 for(uword row=0; row < out_n_rows; ++row)
|
Chris@49
|
266 {
|
Chris@49
|
267 tmp.copy_row(out, row);
|
Chris@49
|
268
|
Chris@49
|
269 for(uword col=0; col < out_n_cols; ++col)
|
Chris@49
|
270 {
|
Chris@49
|
271 out.at(row,col) = op_dot::direct_dot( out_n_cols, tmp_rowdata, B.colptr(col) );
|
Chris@49
|
272 }
|
Chris@49
|
273 }
|
Chris@49
|
274
|
Chris@49
|
275 }
|
Chris@49
|
276 else
|
Chris@49
|
277 {
|
Chris@49
|
278 const Mat<eT> tmp(out);
|
Chris@49
|
279
|
Chris@49
|
280 glue_times::apply<eT, false, false, false>(out, tmp, B, eT(1));
|
Chris@49
|
281 }
|
Chris@49
|
282
|
Chris@49
|
283 }
|
Chris@49
|
284
|
Chris@49
|
285
|
Chris@49
|
286
|
Chris@49
|
287 template<typename T1, typename T2>
|
Chris@49
|
288 arma_hot
|
Chris@49
|
289 inline
|
Chris@49
|
290 void
|
Chris@49
|
291 glue_times::apply_inplace_plus(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times>& X, const sword sign)
|
Chris@49
|
292 {
|
Chris@49
|
293 arma_extra_debug_sigprint();
|
Chris@49
|
294
|
Chris@49
|
295 typedef typename T1::elem_type eT;
|
Chris@49
|
296
|
Chris@49
|
297 const partial_unwrap_check<T1> tmp1(X.A, out);
|
Chris@49
|
298 const partial_unwrap_check<T2> tmp2(X.B, out);
|
Chris@49
|
299
|
Chris@49
|
300 typedef typename partial_unwrap_check<T1>::stored_type TA;
|
Chris@49
|
301 typedef typename partial_unwrap_check<T2>::stored_type TB;
|
Chris@49
|
302
|
Chris@49
|
303 const TA& A = tmp1.M;
|
Chris@49
|
304 const TB& B = tmp2.M;
|
Chris@49
|
305
|
Chris@49
|
306 const bool do_trans_A = partial_unwrap_check<T1>::do_trans;
|
Chris@49
|
307 const bool do_trans_B = partial_unwrap_check<T2>::do_trans;
|
Chris@49
|
308
|
Chris@49
|
309 const bool use_alpha = partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times || (sign < sword(0));
|
Chris@49
|
310 const eT alpha = use_alpha ? ( tmp1.get_val() * tmp2.get_val() * ( (sign > sword(0)) ? eT(1) : eT(-1) ) ) : eT(0);
|
Chris@49
|
311
|
Chris@49
|
312 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication");
|
Chris@49
|
313
|
Chris@49
|
314 const uword result_n_rows = (do_trans_A == false) ? (TA::is_row ? 1 : A.n_rows) : (TA::is_col ? 1 : A.n_cols);
|
Chris@49
|
315 const uword result_n_cols = (do_trans_B == false) ? (TB::is_col ? 1 : B.n_cols) : (TB::is_row ? 1 : B.n_rows);
|
Chris@49
|
316
|
Chris@49
|
317 arma_debug_assert_same_size(out.n_rows, out.n_cols, result_n_rows, result_n_cols, "addition");
|
Chris@49
|
318
|
Chris@49
|
319 if(out.n_elem > 0)
|
Chris@49
|
320 {
|
Chris@49
|
321 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
|
Chris@49
|
322 {
|
Chris@49
|
323 if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex<eT>::value == false) )
|
Chris@49
|
324 {
|
Chris@49
|
325 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
|
Chris@49
|
326 }
|
Chris@49
|
327 else
|
Chris@49
|
328 if( (B.n_cols == 1) || (TB::is_col) )
|
Chris@49
|
329 {
|
Chris@49
|
330 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
|
Chris@49
|
331 }
|
Chris@49
|
332 else
|
Chris@49
|
333 {
|
Chris@49
|
334 gemm<false, false, false, true>::apply(out, A, B, alpha, eT(1));
|
Chris@49
|
335 }
|
Chris@49
|
336 }
|
Chris@49
|
337 else
|
Chris@49
|
338 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
|
Chris@49
|
339 {
|
Chris@49
|
340 if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex<eT>::value == false) )
|
Chris@49
|
341 {
|
Chris@49
|
342 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
|
Chris@49
|
343 }
|
Chris@49
|
344 else
|
Chris@49
|
345 if( (B.n_cols == 1) || (TB::is_col) )
|
Chris@49
|
346 {
|
Chris@49
|
347 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
|
Chris@49
|
348 }
|
Chris@49
|
349 else
|
Chris@49
|
350 {
|
Chris@49
|
351 gemm<false, false, true, true>::apply(out, A, B, alpha, eT(1));
|
Chris@49
|
352 }
|
Chris@49
|
353 }
|
Chris@49
|
354 else
|
Chris@49
|
355 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
|
Chris@49
|
356 {
|
Chris@49
|
357 if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex<eT>::value == false) )
|
Chris@49
|
358 {
|
Chris@49
|
359 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
|
Chris@49
|
360 }
|
Chris@49
|
361 else
|
Chris@49
|
362 if( (B.n_cols == 1) || (TB::is_col) )
|
Chris@49
|
363 {
|
Chris@49
|
364 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
|
Chris@49
|
365 }
|
Chris@49
|
366 else
|
Chris@49
|
367 {
|
Chris@49
|
368 gemm<true, false, false, true>::apply(out, A, B, alpha, eT(1));
|
Chris@49
|
369 }
|
Chris@49
|
370 }
|
Chris@49
|
371 else
|
Chris@49
|
372 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
|
Chris@49
|
373 {
|
Chris@49
|
374 if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex<eT>::value == false) )
|
Chris@49
|
375 {
|
Chris@49
|
376 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
|
Chris@49
|
377 }
|
Chris@49
|
378 else
|
Chris@49
|
379 if( (B.n_cols == 1) || (TB::is_col) )
|
Chris@49
|
380 {
|
Chris@49
|
381 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
|
Chris@49
|
382 }
|
Chris@49
|
383 else
|
Chris@49
|
384 {
|
Chris@49
|
385 gemm<true, false, true, true>::apply(out, A, B, alpha, eT(1));
|
Chris@49
|
386 }
|
Chris@49
|
387 }
|
Chris@49
|
388 else
|
Chris@49
|
389 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
|
Chris@49
|
390 {
|
Chris@49
|
391 if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex<eT>::value == false) )
|
Chris@49
|
392 {
|
Chris@49
|
393 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
|
Chris@49
|
394 }
|
Chris@49
|
395 else
|
Chris@49
|
396 if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex<eT>::value == false) )
|
Chris@49
|
397 {
|
Chris@49
|
398 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
|
Chris@49
|
399 }
|
Chris@49
|
400 else
|
Chris@49
|
401 {
|
Chris@49
|
402 gemm<false, true, false, true>::apply(out, A, B, alpha, eT(1));
|
Chris@49
|
403 }
|
Chris@49
|
404 }
|
Chris@49
|
405 else
|
Chris@49
|
406 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
|
Chris@49
|
407 {
|
Chris@49
|
408 if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex<eT>::value == false) )
|
Chris@49
|
409 {
|
Chris@49
|
410 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
|
Chris@49
|
411 }
|
Chris@49
|
412 else
|
Chris@49
|
413 if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex<eT>::value == false) )
|
Chris@49
|
414 {
|
Chris@49
|
415 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
|
Chris@49
|
416 }
|
Chris@49
|
417 else
|
Chris@49
|
418 {
|
Chris@49
|
419 gemm<false, true, true, true>::apply(out, A, B, alpha, eT(1));
|
Chris@49
|
420 }
|
Chris@49
|
421 }
|
Chris@49
|
422 else
|
Chris@49
|
423 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
|
Chris@49
|
424 {
|
Chris@49
|
425 if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex<eT>::value == false) )
|
Chris@49
|
426 {
|
Chris@49
|
427 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
|
Chris@49
|
428 }
|
Chris@49
|
429 else
|
Chris@49
|
430 if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex<eT>::value == false) )
|
Chris@49
|
431 {
|
Chris@49
|
432 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
|
Chris@49
|
433 }
|
Chris@49
|
434 else
|
Chris@49
|
435 {
|
Chris@49
|
436 gemm<true, true, false, true>::apply(out, A, B, alpha, eT(1));
|
Chris@49
|
437 }
|
Chris@49
|
438 }
|
Chris@49
|
439 else
|
Chris@49
|
440 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
|
Chris@49
|
441 {
|
Chris@49
|
442 if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex<eT>::value == false) )
|
Chris@49
|
443 {
|
Chris@49
|
444 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
|
Chris@49
|
445 }
|
Chris@49
|
446 else
|
Chris@49
|
447 if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex<eT>::value == false) )
|
Chris@49
|
448 {
|
Chris@49
|
449 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
|
Chris@49
|
450 }
|
Chris@49
|
451 else
|
Chris@49
|
452 {
|
Chris@49
|
453 gemm<true, true, true, true>::apply(out, A, B, alpha, eT(1));
|
Chris@49
|
454 }
|
Chris@49
|
455 }
|
Chris@49
|
456 }
|
Chris@49
|
457
|
Chris@49
|
458
|
Chris@49
|
459 }
|
Chris@49
|
460
|
Chris@49
|
461
|
Chris@49
|
462
|
Chris@49
|
463 template<typename eT, const bool do_trans_A, const bool do_trans_B, typename TA, typename TB>
|
Chris@49
|
464 arma_inline
|
Chris@49
|
465 uword
|
Chris@49
|
466 glue_times::mul_storage_cost(const TA& A, const TB& B)
|
Chris@49
|
467 {
|
Chris@49
|
468 const uword final_A_n_rows = (do_trans_A == false) ? ( TA::is_row ? 1 : A.n_rows ) : ( TA::is_col ? 1 : A.n_cols );
|
Chris@49
|
469 const uword final_B_n_cols = (do_trans_B == false) ? ( TB::is_col ? 1 : B.n_cols ) : ( TB::is_row ? 1 : B.n_rows );
|
Chris@49
|
470
|
Chris@49
|
471 return final_A_n_rows * final_B_n_cols;
|
Chris@49
|
472 }
|
Chris@49
|
473
|
Chris@49
|
474
|
Chris@49
|
475
|
Chris@49
|
476 template
|
Chris@49
|
477 <
|
Chris@49
|
478 typename eT,
|
Chris@49
|
479 const bool do_trans_A,
|
Chris@49
|
480 const bool do_trans_B,
|
Chris@49
|
481 const bool use_alpha,
|
Chris@49
|
482 typename TA,
|
Chris@49
|
483 typename TB
|
Chris@49
|
484 >
|
Chris@49
|
485 arma_hot
|
Chris@49
|
486 inline
|
Chris@49
|
487 void
|
Chris@49
|
488 glue_times::apply
|
Chris@49
|
489 (
|
Chris@49
|
490 Mat<eT>& out,
|
Chris@49
|
491 const TA& A,
|
Chris@49
|
492 const TB& B,
|
Chris@49
|
493 const eT alpha
|
Chris@49
|
494 )
|
Chris@49
|
495 {
|
Chris@49
|
496 arma_extra_debug_sigprint();
|
Chris@49
|
497
|
Chris@49
|
498 //arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication");
|
Chris@49
|
499 arma_debug_assert_trans_mul_size<do_trans_A, do_trans_B>(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication");
|
Chris@49
|
500
|
Chris@49
|
501 const uword final_n_rows = (do_trans_A == false) ? (TA::is_row ? 1 : A.n_rows) : (TA::is_col ? 1 : A.n_cols);
|
Chris@49
|
502 const uword final_n_cols = (do_trans_B == false) ? (TB::is_col ? 1 : B.n_cols) : (TB::is_row ? 1 : B.n_rows);
|
Chris@49
|
503
|
Chris@49
|
504 out.set_size(final_n_rows, final_n_cols);
|
Chris@49
|
505
|
Chris@49
|
506 if( (A.n_elem > 0) && (B.n_elem > 0) )
|
Chris@49
|
507 {
|
Chris@49
|
508 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
|
Chris@49
|
509 {
|
Chris@49
|
510 if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex<eT>::value == false) )
|
Chris@49
|
511 {
|
Chris@49
|
512 gemv<true, false, false>::apply(out.memptr(), B, A.memptr());
|
Chris@49
|
513 }
|
Chris@49
|
514 else
|
Chris@49
|
515 if( (B.n_cols == 1) || (TB::is_col) )
|
Chris@49
|
516 {
|
Chris@49
|
517 gemv<false, false, false>::apply(out.memptr(), A, B.memptr());
|
Chris@49
|
518 }
|
Chris@49
|
519 else
|
Chris@49
|
520 {
|
Chris@49
|
521 gemm<false, false, false, false>::apply(out, A, B);
|
Chris@49
|
522 }
|
Chris@49
|
523 }
|
Chris@49
|
524 else
|
Chris@49
|
525 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
|
Chris@49
|
526 {
|
Chris@49
|
527 if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex<eT>::value == false) )
|
Chris@49
|
528 {
|
Chris@49
|
529 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
|
Chris@49
|
530 }
|
Chris@49
|
531 else
|
Chris@49
|
532 if( (B.n_cols == 1) || (TB::is_col) )
|
Chris@49
|
533 {
|
Chris@49
|
534 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
|
Chris@49
|
535 }
|
Chris@49
|
536 else
|
Chris@49
|
537 {
|
Chris@49
|
538 gemm<false, false, true, false>::apply(out, A, B, alpha);
|
Chris@49
|
539 }
|
Chris@49
|
540 }
|
Chris@49
|
541 else
|
Chris@49
|
542 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
|
Chris@49
|
543 {
|
Chris@49
|
544 if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex<eT>::value == false) )
|
Chris@49
|
545 {
|
Chris@49
|
546 gemv<true, false, false>::apply(out.memptr(), B, A.memptr());
|
Chris@49
|
547 }
|
Chris@49
|
548 else
|
Chris@49
|
549 if( (B.n_cols == 1) || (TB::is_col) )
|
Chris@49
|
550 {
|
Chris@49
|
551 gemv<true, false, false>::apply(out.memptr(), A, B.memptr());
|
Chris@49
|
552 }
|
Chris@49
|
553 else
|
Chris@49
|
554 {
|
Chris@49
|
555 gemm<true, false, false, false>::apply(out, A, B);
|
Chris@49
|
556 }
|
Chris@49
|
557 }
|
Chris@49
|
558 else
|
Chris@49
|
559 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
|
Chris@49
|
560 {
|
Chris@49
|
561 if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex<eT>::value == false) )
|
Chris@49
|
562 {
|
Chris@49
|
563 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
|
Chris@49
|
564 }
|
Chris@49
|
565 else
|
Chris@49
|
566 if( (B.n_cols == 1) || (TB::is_col) )
|
Chris@49
|
567 {
|
Chris@49
|
568 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
|
Chris@49
|
569 }
|
Chris@49
|
570 else
|
Chris@49
|
571 {
|
Chris@49
|
572 gemm<true, false, true, false>::apply(out, A, B, alpha);
|
Chris@49
|
573 }
|
Chris@49
|
574 }
|
Chris@49
|
575 else
|
Chris@49
|
576 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
|
Chris@49
|
577 {
|
Chris@49
|
578 if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex<eT>::value == false) )
|
Chris@49
|
579 {
|
Chris@49
|
580 gemv<false, false, false>::apply(out.memptr(), B, A.memptr());
|
Chris@49
|
581 }
|
Chris@49
|
582 else
|
Chris@49
|
583 if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex<eT>::value == false) )
|
Chris@49
|
584 {
|
Chris@49
|
585 gemv<false, false, false>::apply(out.memptr(), A, B.memptr());
|
Chris@49
|
586 }
|
Chris@49
|
587 else
|
Chris@49
|
588 {
|
Chris@49
|
589 gemm<false, true, false, false>::apply(out, A, B);
|
Chris@49
|
590 }
|
Chris@49
|
591 }
|
Chris@49
|
592 else
|
Chris@49
|
593 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
|
Chris@49
|
594 {
|
Chris@49
|
595 if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex<eT>::value == false) )
|
Chris@49
|
596 {
|
Chris@49
|
597 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
|
Chris@49
|
598 }
|
Chris@49
|
599 else
|
Chris@49
|
600 if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex<eT>::value == false) )
|
Chris@49
|
601 {
|
Chris@49
|
602 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
|
Chris@49
|
603 }
|
Chris@49
|
604 else
|
Chris@49
|
605 {
|
Chris@49
|
606 gemm<false, true, true, false>::apply(out, A, B, alpha);
|
Chris@49
|
607 }
|
Chris@49
|
608 }
|
Chris@49
|
609 else
|
Chris@49
|
610 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
|
Chris@49
|
611 {
|
Chris@49
|
612 if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex<eT>::value == false) )
|
Chris@49
|
613 {
|
Chris@49
|
614 gemv<false, false, false>::apply(out.memptr(), B, A.memptr());
|
Chris@49
|
615 }
|
Chris@49
|
616 else
|
Chris@49
|
617 if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex<eT>::value == false) )
|
Chris@49
|
618 {
|
Chris@49
|
619 gemv<true, false, false>::apply(out.memptr(), A, B.memptr());
|
Chris@49
|
620 }
|
Chris@49
|
621 else
|
Chris@49
|
622 {
|
Chris@49
|
623 gemm<true, true, false, false>::apply(out, A, B);
|
Chris@49
|
624 }
|
Chris@49
|
625 }
|
Chris@49
|
626 else
|
Chris@49
|
627 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
|
Chris@49
|
628 {
|
Chris@49
|
629 if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex<eT>::value == false) )
|
Chris@49
|
630 {
|
Chris@49
|
631 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
|
Chris@49
|
632 }
|
Chris@49
|
633 else
|
Chris@49
|
634 if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex<eT>::value == false) )
|
Chris@49
|
635 {
|
Chris@49
|
636 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
|
Chris@49
|
637 }
|
Chris@49
|
638 else
|
Chris@49
|
639 {
|
Chris@49
|
640 gemm<true, true, true, false>::apply(out, A, B, alpha);
|
Chris@49
|
641 }
|
Chris@49
|
642 }
|
Chris@49
|
643 }
|
Chris@49
|
644 else
|
Chris@49
|
645 {
|
Chris@49
|
646 out.zeros();
|
Chris@49
|
647 }
|
Chris@49
|
648 }
|
Chris@49
|
649
|
Chris@49
|
650
|
Chris@49
|
651
|
Chris@49
|
652 template
|
Chris@49
|
653 <
|
Chris@49
|
654 typename eT,
|
Chris@49
|
655 const bool do_trans_A,
|
Chris@49
|
656 const bool do_trans_B,
|
Chris@49
|
657 const bool do_trans_C,
|
Chris@49
|
658 const bool use_alpha,
|
Chris@49
|
659 typename TA,
|
Chris@49
|
660 typename TB,
|
Chris@49
|
661 typename TC
|
Chris@49
|
662 >
|
Chris@49
|
663 arma_hot
|
Chris@49
|
664 inline
|
Chris@49
|
665 void
|
Chris@49
|
666 glue_times::apply
|
Chris@49
|
667 (
|
Chris@49
|
668 Mat<eT>& out,
|
Chris@49
|
669 const TA& A,
|
Chris@49
|
670 const TB& B,
|
Chris@49
|
671 const TC& C,
|
Chris@49
|
672 const eT alpha
|
Chris@49
|
673 )
|
Chris@49
|
674 {
|
Chris@49
|
675 arma_extra_debug_sigprint();
|
Chris@49
|
676
|
Chris@49
|
677 Mat<eT> tmp;
|
Chris@49
|
678
|
Chris@49
|
679 const uword storage_cost_AB = glue_times::mul_storage_cost<eT, do_trans_A, do_trans_B>(A, B);
|
Chris@49
|
680 const uword storage_cost_BC = glue_times::mul_storage_cost<eT, do_trans_B, do_trans_C>(B, C);
|
Chris@49
|
681
|
Chris@49
|
682 if(storage_cost_AB <= storage_cost_BC)
|
Chris@49
|
683 {
|
Chris@49
|
684 // out = (A*B)*C
|
Chris@49
|
685
|
Chris@49
|
686 glue_times::apply<eT, do_trans_A, do_trans_B, use_alpha>(tmp, A, B, alpha);
|
Chris@49
|
687 glue_times::apply<eT, false, do_trans_C, false >(out, tmp, C, eT(0));
|
Chris@49
|
688 }
|
Chris@49
|
689 else
|
Chris@49
|
690 {
|
Chris@49
|
691 // out = A*(B*C)
|
Chris@49
|
692
|
Chris@49
|
693 glue_times::apply<eT, do_trans_B, do_trans_C, use_alpha>(tmp, B, C, alpha);
|
Chris@49
|
694 glue_times::apply<eT, do_trans_A, false, false >(out, A, tmp, eT(0));
|
Chris@49
|
695 }
|
Chris@49
|
696 }
|
Chris@49
|
697
|
Chris@49
|
698
|
Chris@49
|
699
|
Chris@49
|
700 template
|
Chris@49
|
701 <
|
Chris@49
|
702 typename eT,
|
Chris@49
|
703 const bool do_trans_A,
|
Chris@49
|
704 const bool do_trans_B,
|
Chris@49
|
705 const bool do_trans_C,
|
Chris@49
|
706 const bool do_trans_D,
|
Chris@49
|
707 const bool use_alpha,
|
Chris@49
|
708 typename TA,
|
Chris@49
|
709 typename TB,
|
Chris@49
|
710 typename TC,
|
Chris@49
|
711 typename TD
|
Chris@49
|
712 >
|
Chris@49
|
713 arma_hot
|
Chris@49
|
714 inline
|
Chris@49
|
715 void
|
Chris@49
|
716 glue_times::apply
|
Chris@49
|
717 (
|
Chris@49
|
718 Mat<eT>& out,
|
Chris@49
|
719 const TA& A,
|
Chris@49
|
720 const TB& B,
|
Chris@49
|
721 const TC& C,
|
Chris@49
|
722 const TD& D,
|
Chris@49
|
723 const eT alpha
|
Chris@49
|
724 )
|
Chris@49
|
725 {
|
Chris@49
|
726 arma_extra_debug_sigprint();
|
Chris@49
|
727
|
Chris@49
|
728 Mat<eT> tmp;
|
Chris@49
|
729
|
Chris@49
|
730 const uword storage_cost_AC = glue_times::mul_storage_cost<eT, do_trans_A, do_trans_C>(A, C);
|
Chris@49
|
731 const uword storage_cost_BD = glue_times::mul_storage_cost<eT, do_trans_B, do_trans_D>(B, D);
|
Chris@49
|
732
|
Chris@49
|
733 if(storage_cost_AC <= storage_cost_BD)
|
Chris@49
|
734 {
|
Chris@49
|
735 // out = (A*B*C)*D
|
Chris@49
|
736
|
Chris@49
|
737 glue_times::apply<eT, do_trans_A, do_trans_B, do_trans_C, use_alpha>(tmp, A, B, C, alpha);
|
Chris@49
|
738
|
Chris@49
|
739 glue_times::apply<eT, false, do_trans_D, false>(out, tmp, D, eT(0));
|
Chris@49
|
740 }
|
Chris@49
|
741 else
|
Chris@49
|
742 {
|
Chris@49
|
743 // out = A*(B*C*D)
|
Chris@49
|
744
|
Chris@49
|
745 glue_times::apply<eT, do_trans_B, do_trans_C, do_trans_D, use_alpha>(tmp, B, C, D, alpha);
|
Chris@49
|
746
|
Chris@49
|
747 glue_times::apply<eT, do_trans_A, false, false>(out, A, tmp, eT(0));
|
Chris@49
|
748 }
|
Chris@49
|
749 }
|
Chris@49
|
750
|
Chris@49
|
751
|
Chris@49
|
752
|
Chris@49
|
753 //
|
Chris@49
|
754 // glue_times_diag
|
Chris@49
|
755
|
Chris@49
|
756
|
Chris@49
|
757 template<typename T1, typename T2>
|
Chris@49
|
758 arma_hot
|
Chris@49
|
759 inline
|
Chris@49
|
760 void
|
Chris@49
|
761 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times_diag>& X)
|
Chris@49
|
762 {
|
Chris@49
|
763 arma_extra_debug_sigprint();
|
Chris@49
|
764
|
Chris@49
|
765 typedef typename T1::elem_type eT;
|
Chris@49
|
766
|
Chris@49
|
767 const strip_diagmat<T1> S1(X.A);
|
Chris@49
|
768 const strip_diagmat<T2> S2(X.B);
|
Chris@49
|
769
|
Chris@49
|
770 typedef typename strip_diagmat<T1>::stored_type T1_stripped;
|
Chris@49
|
771 typedef typename strip_diagmat<T2>::stored_type T2_stripped;
|
Chris@49
|
772
|
Chris@49
|
773 if( (strip_diagmat<T1>::do_diagmat == true) && (strip_diagmat<T2>::do_diagmat == false) )
|
Chris@49
|
774 {
|
Chris@49
|
775 const diagmat_proxy_check<T1_stripped> A(S1.M, out);
|
Chris@49
|
776
|
Chris@49
|
777 const unwrap_check<T2> tmp(X.B, out);
|
Chris@49
|
778 const Mat<eT>& B = tmp.M;
|
Chris@49
|
779
|
Chris@49
|
780 const uword A_n_elem = A.n_elem;
|
Chris@49
|
781 const uword B_n_rows = B.n_rows;
|
Chris@49
|
782 const uword B_n_cols = B.n_cols;
|
Chris@49
|
783
|
Chris@49
|
784 arma_debug_assert_mul_size(A_n_elem, A_n_elem, B_n_rows, B_n_cols, "matrix multiplication");
|
Chris@49
|
785
|
Chris@49
|
786 out.set_size(A_n_elem, B_n_cols);
|
Chris@49
|
787
|
Chris@49
|
788 for(uword col=0; col < B_n_cols; ++col)
|
Chris@49
|
789 {
|
Chris@49
|
790 eT* out_coldata = out.colptr(col);
|
Chris@49
|
791 const eT* B_coldata = B.colptr(col);
|
Chris@49
|
792
|
Chris@49
|
793 uword i,j;
|
Chris@49
|
794 for(i=0, j=1; j < B_n_rows; i+=2, j+=2)
|
Chris@49
|
795 {
|
Chris@49
|
796 eT tmp_i = A[i];
|
Chris@49
|
797 eT tmp_j = A[j];
|
Chris@49
|
798
|
Chris@49
|
799 tmp_i *= B_coldata[i];
|
Chris@49
|
800 tmp_j *= B_coldata[j];
|
Chris@49
|
801
|
Chris@49
|
802 out_coldata[i] = tmp_i;
|
Chris@49
|
803 out_coldata[j] = tmp_j;
|
Chris@49
|
804 }
|
Chris@49
|
805
|
Chris@49
|
806 if(i < B_n_rows)
|
Chris@49
|
807 {
|
Chris@49
|
808 out_coldata[i] = A[i] * B_coldata[i];
|
Chris@49
|
809 }
|
Chris@49
|
810 }
|
Chris@49
|
811 }
|
Chris@49
|
812 else
|
Chris@49
|
813 if( (strip_diagmat<T1>::do_diagmat == false) && (strip_diagmat<T2>::do_diagmat == true) )
|
Chris@49
|
814 {
|
Chris@49
|
815 const unwrap_check<T1> tmp(X.A, out);
|
Chris@49
|
816 const Mat<eT>& A = tmp.M;
|
Chris@49
|
817
|
Chris@49
|
818 const diagmat_proxy_check<T2_stripped> B(S2.M, out);
|
Chris@49
|
819
|
Chris@49
|
820 const uword A_n_rows = A.n_rows;
|
Chris@49
|
821 const uword A_n_cols = A.n_cols;
|
Chris@49
|
822 const uword B_n_elem = B.n_elem;
|
Chris@49
|
823
|
Chris@49
|
824 arma_debug_assert_mul_size(A_n_rows, A_n_cols, B_n_elem, B_n_elem, "matrix multiplication");
|
Chris@49
|
825
|
Chris@49
|
826 out.set_size(A_n_rows, B_n_elem);
|
Chris@49
|
827
|
Chris@49
|
828 for(uword col=0; col < A_n_cols; ++col)
|
Chris@49
|
829 {
|
Chris@49
|
830 const eT val = B[col];
|
Chris@49
|
831
|
Chris@49
|
832 eT* out_coldata = out.colptr(col);
|
Chris@49
|
833 const eT* A_coldata = A.colptr(col);
|
Chris@49
|
834
|
Chris@49
|
835 uword i,j;
|
Chris@49
|
836 for(i=0, j=1; j < A_n_rows; i+=2, j+=2)
|
Chris@49
|
837 {
|
Chris@49
|
838 const eT tmp_i = A_coldata[i] * val;
|
Chris@49
|
839 const eT tmp_j = A_coldata[j] * val;
|
Chris@49
|
840
|
Chris@49
|
841 out_coldata[i] = tmp_i;
|
Chris@49
|
842 out_coldata[j] = tmp_j;
|
Chris@49
|
843 }
|
Chris@49
|
844
|
Chris@49
|
845 if(i < A_n_rows)
|
Chris@49
|
846 {
|
Chris@49
|
847 out_coldata[i] = A_coldata[i] * val;
|
Chris@49
|
848 }
|
Chris@49
|
849 }
|
Chris@49
|
850 }
|
Chris@49
|
851 else
|
Chris@49
|
852 if( (strip_diagmat<T1>::do_diagmat == true) && (strip_diagmat<T2>::do_diagmat == true) )
|
Chris@49
|
853 {
|
Chris@49
|
854 const diagmat_proxy_check<T1_stripped> A(S1.M, out);
|
Chris@49
|
855 const diagmat_proxy_check<T2_stripped> B(S2.M, out);
|
Chris@49
|
856
|
Chris@49
|
857 const uword A_n_elem = A.n_elem;
|
Chris@49
|
858 const uword B_n_elem = B.n_elem;
|
Chris@49
|
859
|
Chris@49
|
860 arma_debug_assert_mul_size(A_n_elem, A_n_elem, B_n_elem, B_n_elem, "matrix multiplication");
|
Chris@49
|
861
|
Chris@49
|
862 out.zeros(A_n_elem, A_n_elem);
|
Chris@49
|
863
|
Chris@49
|
864 for(uword i=0; i < A_n_elem; ++i)
|
Chris@49
|
865 {
|
Chris@49
|
866 out.at(i,i) = A[i] * B[i];
|
Chris@49
|
867 }
|
Chris@49
|
868 }
|
Chris@49
|
869 }
|
Chris@49
|
870
|
Chris@49
|
871
|
Chris@49
|
872
|
Chris@49
|
873 //! @}
|