Mercurial > hg > segmenter-vamp-plugin
comparison armadillo-2.4.4/include/armadillo_bits/glue_times_meat.hpp @ 0:8b6102e2a9b0
Armadillo Library
author | maxzanoni76 <max.zanoni@eecs.qmul.ac.uk> |
---|---|
date | Wed, 11 Apr 2012 09:27:06 +0100 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:8b6102e2a9b0 |
---|---|
1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au) | |
2 // Copyright (C) 2008-2011 Conrad Sanderson | |
3 // | |
4 // This file is part of the Armadillo C++ library. | |
5 // It is provided without any warranty of fitness | |
6 // for any purpose. You can redistribute this file | |
7 // and/or modify it under the terms of the GNU | |
8 // Lesser General Public License (LGPL) as published | |
9 // by the Free Software Foundation, either version 3 | |
10 // of the License or (at your option) any later version. | |
11 // (see http://www.opensource.org/licenses for more info) | |
12 | |
13 | |
14 //! \addtogroup glue_times | |
15 //! @{ | |
16 | |
17 | |
18 | |
19 template<uword N> | |
20 template<typename T1, typename T2> | |
21 inline | |
22 void | |
23 glue_times_redirect<N>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X) | |
24 { | |
25 arma_extra_debug_sigprint(); | |
26 | |
27 typedef typename T1::elem_type eT; | |
28 | |
29 const partial_unwrap_check<T1> tmp1(X.A, out); | |
30 const partial_unwrap_check<T2> tmp2(X.B, out); | |
31 | |
32 const Mat<eT>& A = tmp1.M; | |
33 const Mat<eT>& B = tmp2.M; | |
34 | |
35 const bool do_trans_A = tmp1.do_trans; | |
36 const bool do_trans_B = tmp2.do_trans; | |
37 | |
38 const bool use_alpha = tmp1.do_times || tmp2.do_times; | |
39 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0); | |
40 | |
41 glue_times::apply(out, A, B, alpha, do_trans_A, do_trans_B, use_alpha); | |
42 } | |
43 | |
44 | |
45 | |
46 template<typename T1, typename T2, typename T3> | |
47 inline | |
48 void | |
49 glue_times_redirect<3>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue<T1,T2,glue_times>, T3, glue_times>& X) | |
50 { | |
51 arma_extra_debug_sigprint(); | |
52 | |
53 typedef typename T1::elem_type eT; | |
54 | |
55 // there is exactly 3 objects | |
56 // hence we can safely expand X as X.A.A, X.A.B and X.B | |
57 | |
58 const partial_unwrap_check<T1> tmp1(X.A.A, out); | |
59 const partial_unwrap_check<T2> tmp2(X.A.B, out); | |
60 const partial_unwrap_check<T3> tmp3(X.B, out); | |
61 | |
62 const Mat<eT>& A = tmp1.M; | |
63 const Mat<eT>& B = tmp2.M; | |
64 const Mat<eT>& C = tmp3.M; | |
65 | |
66 const bool do_trans_A = tmp1.do_trans; | |
67 const bool do_trans_B = tmp2.do_trans; | |
68 const bool do_trans_C = tmp3.do_trans; | |
69 | |
70 const bool use_alpha = tmp1.do_times || tmp2.do_times || tmp3.do_times; | |
71 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val()) : eT(0); | |
72 | |
73 glue_times::apply(out, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha); | |
74 } | |
75 | |
76 | |
77 | |
78 template<typename T1, typename T2, typename T3, typename T4> | |
79 inline | |
80 void | |
81 glue_times_redirect<4>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue< Glue<T1,T2,glue_times>, T3, glue_times>, T4, glue_times>& X) | |
82 { | |
83 arma_extra_debug_sigprint(); | |
84 | |
85 typedef typename T1::elem_type eT; | |
86 | |
87 // there is exactly 4 objects | |
88 // hence we can safely expand X as X.A.A.A, X.A.A.B, X.A.B and X.B | |
89 | |
90 const partial_unwrap_check<T1> tmp1(X.A.A.A, out); | |
91 const partial_unwrap_check<T2> tmp2(X.A.A.B, out); | |
92 const partial_unwrap_check<T3> tmp3(X.A.B, out); | |
93 const partial_unwrap_check<T4> tmp4(X.B, out); | |
94 | |
95 const Mat<eT>& A = tmp1.M; | |
96 const Mat<eT>& B = tmp2.M; | |
97 const Mat<eT>& C = tmp3.M; | |
98 const Mat<eT>& D = tmp4.M; | |
99 | |
100 const bool do_trans_A = tmp1.do_trans; | |
101 const bool do_trans_B = tmp2.do_trans; | |
102 const bool do_trans_C = tmp3.do_trans; | |
103 const bool do_trans_D = tmp4.do_trans; | |
104 | |
105 const bool use_alpha = tmp1.do_times || tmp2.do_times || tmp3.do_times || tmp4.do_times; | |
106 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val() * tmp4.get_val()) : eT(0); | |
107 | |
108 glue_times::apply(out, A, B, C, D, alpha, do_trans_A, do_trans_B, do_trans_C, do_trans_D, use_alpha); | |
109 } | |
110 | |
111 | |
112 | |
113 template<typename T1, typename T2> | |
114 inline | |
115 void | |
116 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X) | |
117 { | |
118 arma_extra_debug_sigprint(); | |
119 | |
120 typedef typename T1::elem_type eT; | |
121 | |
122 const sword N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num; | |
123 | |
124 arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat); | |
125 | |
126 glue_times_redirect<N_mat>::apply(out, X); | |
127 } | |
128 | |
129 | |
130 | |
131 template<typename T1> | |
132 inline | |
133 void | |
134 glue_times::apply_inplace(Mat<typename T1::elem_type>& out, const T1& X) | |
135 { | |
136 arma_extra_debug_sigprint(); | |
137 | |
138 typedef typename T1::elem_type eT; | |
139 | |
140 const unwrap_check<T1> tmp(X, out); | |
141 const Mat<eT>& B = tmp.M; | |
142 | |
143 arma_debug_assert_mul_size(out, B, "matrix multiplication"); | |
144 | |
145 const uword out_n_rows = out.n_rows; | |
146 const uword out_n_cols = out.n_cols; | |
147 | |
148 if(out_n_cols == B.n_cols) | |
149 { | |
150 // size of resulting matrix is the same as 'out' | |
151 | |
152 podarray<eT> tmp(out_n_cols); | |
153 | |
154 eT* tmp_rowdata = tmp.memptr(); | |
155 | |
156 for(uword row=0; row < out_n_rows; ++row) | |
157 { | |
158 tmp.copy_row(out, row); | |
159 | |
160 for(uword col=0; col < out_n_cols; ++col) | |
161 { | |
162 out.at(row,col) = op_dot::direct_dot( out_n_cols, tmp_rowdata, B.colptr(col) ); | |
163 } | |
164 } | |
165 | |
166 } | |
167 else | |
168 { | |
169 const Mat<eT> tmp(out); | |
170 glue_times::apply(out, tmp, B, eT(1), false, false, false); | |
171 } | |
172 | |
173 } | |
174 | |
175 | |
176 | |
177 template<typename T1, typename T2> | |
178 arma_hot | |
179 inline | |
180 void | |
181 glue_times::apply_inplace_plus(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times>& X, const sword sign) | |
182 { | |
183 arma_extra_debug_sigprint(); | |
184 | |
185 typedef typename T1::elem_type eT; | |
186 | |
187 const partial_unwrap_check<T1> tmp1(X.A, out); | |
188 const partial_unwrap_check<T2> tmp2(X.B, out); | |
189 | |
190 const Mat<eT>& A = tmp1.M; | |
191 const Mat<eT>& B = tmp2.M; | |
192 const eT alpha = tmp1.get_val() * tmp2.get_val() * ( (sign > sword(0)) ? eT(1) : eT(-1) ); | |
193 | |
194 const bool do_trans_A = tmp1.do_trans; | |
195 const bool do_trans_B = tmp2.do_trans; | |
196 const bool use_alpha = tmp1.do_times || tmp2.do_times || (sign < sword(0)); | |
197 | |
198 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication"); | |
199 | |
200 const uword result_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols; | |
201 const uword result_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows; | |
202 | |
203 arma_assert_same_size(out.n_rows, out.n_cols, result_n_rows, result_n_cols, "addition"); | |
204 | |
205 if(out.n_elem > 0) | |
206 { | |
207 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) ) | |
208 { | |
209 if( (A.n_rows == 1) && (is_complex<eT>::value == false) ) | |
210 { | |
211 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); | |
212 } | |
213 else | |
214 if(B.n_cols == 1) | |
215 { | |
216 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); | |
217 } | |
218 else | |
219 { | |
220 gemm<false, false, false, true>::apply(out, A, B, alpha, eT(1)); | |
221 } | |
222 } | |
223 else | |
224 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) ) | |
225 { | |
226 if( (A.n_rows == 1) && (is_complex<eT>::value == false) ) | |
227 { | |
228 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); | |
229 } | |
230 else | |
231 if(B.n_cols == 1) | |
232 { | |
233 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); | |
234 } | |
235 else | |
236 { | |
237 gemm<false, false, true, true>::apply(out, A, B, alpha, eT(1)); | |
238 } | |
239 } | |
240 else | |
241 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) ) | |
242 { | |
243 if( (A.n_cols == 1) && (is_complex<eT>::value == false) ) | |
244 { | |
245 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); | |
246 } | |
247 else | |
248 if(B.n_cols == 1) | |
249 { | |
250 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); | |
251 } | |
252 else | |
253 { | |
254 gemm<true, false, false, true>::apply(out, A, B, alpha, eT(1)); | |
255 } | |
256 } | |
257 else | |
258 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) ) | |
259 { | |
260 if( (A.n_cols == 1) && (is_complex<eT>::value == false) ) | |
261 { | |
262 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); | |
263 } | |
264 else | |
265 if(B.n_cols == 1) | |
266 { | |
267 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); | |
268 } | |
269 else | |
270 { | |
271 gemm<true, false, true, true>::apply(out, A, B, alpha, eT(1)); | |
272 } | |
273 } | |
274 else | |
275 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) ) | |
276 { | |
277 if( (A.n_rows == 1) && (is_complex<eT>::value == false) ) | |
278 { | |
279 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); | |
280 } | |
281 else | |
282 if( (B.n_rows == 1) && (is_complex<eT>::value == false) ) | |
283 { | |
284 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); | |
285 } | |
286 else | |
287 { | |
288 gemm<false, true, false, true>::apply(out, A, B, alpha, eT(1)); | |
289 } | |
290 } | |
291 else | |
292 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) ) | |
293 { | |
294 if( (A.n_rows == 1) && (is_complex<eT>::value == false) ) | |
295 { | |
296 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); | |
297 } | |
298 else | |
299 if( (B.n_rows == 1) && (is_complex<eT>::value == false) ) | |
300 { | |
301 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); | |
302 } | |
303 else | |
304 { | |
305 gemm<false, true, true, true>::apply(out, A, B, alpha, eT(1)); | |
306 } | |
307 } | |
308 else | |
309 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) ) | |
310 { | |
311 if( (A.n_cols == 1) && (is_complex<eT>::value == false) ) | |
312 { | |
313 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); | |
314 } | |
315 else | |
316 if( (B.n_rows == 1) && (is_complex<eT>::value == false) ) | |
317 { | |
318 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); | |
319 } | |
320 else | |
321 { | |
322 gemm<true, true, false, true>::apply(out, A, B, alpha, eT(1)); | |
323 } | |
324 } | |
325 else | |
326 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) ) | |
327 { | |
328 if( (A.n_cols == 1) && (is_complex<eT>::value == false) ) | |
329 { | |
330 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); | |
331 } | |
332 else | |
333 if( (B.n_rows == 1) && (is_complex<eT>::value == false) ) | |
334 { | |
335 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); | |
336 } | |
337 else | |
338 { | |
339 gemm<true, true, true, true>::apply(out, A, B, alpha, eT(1)); | |
340 } | |
341 } | |
342 } | |
343 | |
344 | |
345 } | |
346 | |
347 | |
348 | |
349 template<typename eT> | |
350 arma_inline | |
351 uword | |
352 glue_times::mul_storage_cost(const Mat<eT>& A, const Mat<eT>& B, const bool do_trans_A, const bool do_trans_B) | |
353 { | |
354 const uword final_A_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols; | |
355 const uword final_B_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows; | |
356 | |
357 return final_A_n_rows * final_B_n_cols; | |
358 } | |
359 | |
360 | |
361 | |
362 template<typename eT> | |
363 arma_hot | |
364 inline | |
365 void | |
366 glue_times::apply | |
367 ( | |
368 Mat<eT>& out, | |
369 const Mat<eT>& A, | |
370 const Mat<eT>& B, | |
371 const eT alpha, | |
372 const bool do_trans_A, | |
373 const bool do_trans_B, | |
374 const bool use_alpha | |
375 ) | |
376 { | |
377 arma_extra_debug_sigprint(); | |
378 | |
379 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication"); | |
380 | |
381 const uword final_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols; | |
382 const uword final_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows; | |
383 | |
384 out.set_size(final_n_rows, final_n_cols); | |
385 | |
386 if( (A.n_elem > 0) && (B.n_elem > 0) ) | |
387 { | |
388 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) ) | |
389 { | |
390 if( (A.n_rows == 1) && (is_complex<eT>::value == false) ) | |
391 { | |
392 gemv<true, false, false>::apply(out.memptr(), B, A.memptr()); | |
393 } | |
394 else | |
395 if(B.n_cols == 1) | |
396 { | |
397 gemv<false, false, false>::apply(out.memptr(), A, B.memptr()); | |
398 } | |
399 else | |
400 { | |
401 gemm<false, false, false, false>::apply(out, A, B); | |
402 } | |
403 } | |
404 else | |
405 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) ) | |
406 { | |
407 if( (A.n_rows == 1) && (is_complex<eT>::value == false) ) | |
408 { | |
409 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha); | |
410 } | |
411 else | |
412 if(B.n_cols == 1) | |
413 { | |
414 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha); | |
415 } | |
416 else | |
417 { | |
418 gemm<false, false, true, false>::apply(out, A, B, alpha); | |
419 } | |
420 } | |
421 else | |
422 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) ) | |
423 { | |
424 if( (A.n_cols == 1) && (is_complex<eT>::value == false) ) | |
425 { | |
426 gemv<true, false, false>::apply(out.memptr(), B, A.memptr()); | |
427 } | |
428 else | |
429 if(B.n_cols == 1) | |
430 { | |
431 gemv<true, false, false>::apply(out.memptr(), A, B.memptr()); | |
432 } | |
433 else | |
434 { | |
435 gemm<true, false, false, false>::apply(out, A, B); | |
436 } | |
437 } | |
438 else | |
439 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) ) | |
440 { | |
441 if( (A.n_cols == 1) && (is_complex<eT>::value == false) ) | |
442 { | |
443 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha); | |
444 } | |
445 else | |
446 if(B.n_cols == 1) | |
447 { | |
448 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha); | |
449 } | |
450 else | |
451 { | |
452 gemm<true, false, true, false>::apply(out, A, B, alpha); | |
453 } | |
454 } | |
455 else | |
456 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) ) | |
457 { | |
458 if( (A.n_rows == 1) && (is_complex<eT>::value == false) ) | |
459 { | |
460 gemv<false, false, false>::apply(out.memptr(), B, A.memptr()); | |
461 } | |
462 else | |
463 if( (B.n_rows == 1) && (is_complex<eT>::value == false) ) | |
464 { | |
465 gemv<false, false, false>::apply(out.memptr(), A, B.memptr()); | |
466 } | |
467 else | |
468 { | |
469 gemm<false, true, false, false>::apply(out, A, B); | |
470 } | |
471 } | |
472 else | |
473 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) ) | |
474 { | |
475 if( (A.n_rows == 1) && (is_complex<eT>::value == false) ) | |
476 { | |
477 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha); | |
478 } | |
479 else | |
480 if( (B.n_rows == 1) && (is_complex<eT>::value == false) ) | |
481 { | |
482 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha); | |
483 } | |
484 else | |
485 { | |
486 gemm<false, true, true, false>::apply(out, A, B, alpha); | |
487 } | |
488 } | |
489 else | |
490 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) ) | |
491 { | |
492 if( (A.n_cols == 1) && (is_complex<eT>::value == false) ) | |
493 { | |
494 gemv<false, false, false>::apply(out.memptr(), B, A.memptr()); | |
495 } | |
496 else | |
497 if( (B.n_rows == 1) && (is_complex<eT>::value == false) ) | |
498 { | |
499 gemv<true, false, false>::apply(out.memptr(), A, B.memptr()); | |
500 } | |
501 else | |
502 { | |
503 gemm<true, true, false, false>::apply(out, A, B); | |
504 } | |
505 } | |
506 else | |
507 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) ) | |
508 { | |
509 if( (A.n_cols == 1) && (is_complex<eT>::value == false) ) | |
510 { | |
511 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha); | |
512 } | |
513 else | |
514 if( (B.n_rows == 1) && (is_complex<eT>::value == false) ) | |
515 { | |
516 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha); | |
517 } | |
518 else | |
519 { | |
520 gemm<true, true, true, false>::apply(out, A, B, alpha); | |
521 } | |
522 } | |
523 } | |
524 else | |
525 { | |
526 out.zeros(); | |
527 } | |
528 } | |
529 | |
530 | |
531 | |
532 template<typename eT> | |
533 inline | |
534 void | |
535 glue_times::apply | |
536 ( | |
537 Mat<eT>& out, | |
538 const Mat<eT>& A, | |
539 const Mat<eT>& B, | |
540 const Mat<eT>& C, | |
541 const eT alpha, | |
542 const bool do_trans_A, | |
543 const bool do_trans_B, | |
544 const bool do_trans_C, | |
545 const bool use_alpha | |
546 ) | |
547 { | |
548 arma_extra_debug_sigprint(); | |
549 | |
550 Mat<eT> tmp; | |
551 | |
552 if( glue_times::mul_storage_cost(A, B, do_trans_A, do_trans_B) <= glue_times::mul_storage_cost(B, C, do_trans_B, do_trans_C) ) | |
553 { | |
554 // out = (A*B)*C | |
555 glue_times::apply(tmp, A, B, alpha, do_trans_A, do_trans_B, use_alpha); | |
556 glue_times::apply(out, tmp, C, eT(0), false, do_trans_C, false ); | |
557 } | |
558 else | |
559 { | |
560 // out = A*(B*C) | |
561 glue_times::apply(tmp, B, C, alpha, do_trans_B, do_trans_C, use_alpha); | |
562 glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false ); | |
563 } | |
564 } | |
565 | |
566 | |
567 | |
568 template<typename eT> | |
569 inline | |
570 void | |
571 glue_times::apply | |
572 ( | |
573 Mat<eT>& out, | |
574 const Mat<eT>& A, | |
575 const Mat<eT>& B, | |
576 const Mat<eT>& C, | |
577 const Mat<eT>& D, | |
578 const eT alpha, | |
579 const bool do_trans_A, | |
580 const bool do_trans_B, | |
581 const bool do_trans_C, | |
582 const bool do_trans_D, | |
583 const bool use_alpha | |
584 ) | |
585 { | |
586 arma_extra_debug_sigprint(); | |
587 | |
588 Mat<eT> tmp; | |
589 | |
590 if( glue_times::mul_storage_cost(A, C, do_trans_A, do_trans_C) <= glue_times::mul_storage_cost(B, D, do_trans_B, do_trans_D) ) | |
591 { | |
592 // out = (A*B*C)*D | |
593 glue_times::apply(tmp, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha); | |
594 | |
595 glue_times::apply(out, tmp, D, eT(0), false, do_trans_D, false); | |
596 } | |
597 else | |
598 { | |
599 // out = A*(B*C*D) | |
600 glue_times::apply(tmp, B, C, D, alpha, do_trans_B, do_trans_C, do_trans_D, use_alpha); | |
601 | |
602 glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false); | |
603 } | |
604 } | |
605 | |
606 | |
607 | |
608 // | |
609 // glue_times_diag | |
610 | |
611 | |
612 template<typename T1, typename T2> | |
613 arma_hot | |
614 inline | |
615 void | |
616 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times_diag>& X) | |
617 { | |
618 arma_extra_debug_sigprint(); | |
619 | |
620 typedef typename T1::elem_type eT; | |
621 | |
622 const strip_diagmat<T1> S1(X.A); | |
623 const strip_diagmat<T2> S2(X.B); | |
624 | |
625 typedef typename strip_diagmat<T1>::stored_type T1_stripped; | |
626 typedef typename strip_diagmat<T2>::stored_type T2_stripped; | |
627 | |
628 if( (S1.do_diagmat == true) && (S2.do_diagmat == false) ) | |
629 { | |
630 const diagmat_proxy_check<T1_stripped> A(S1.M, out); | |
631 | |
632 const unwrap_check<T2> tmp(X.B, out); | |
633 const Mat<eT>& B = tmp.M; | |
634 | |
635 arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_rows, B.n_cols, "matrix multiplication"); | |
636 | |
637 out.set_size(A.n_elem, B.n_cols); | |
638 | |
639 for(uword col=0; col<B.n_cols; ++col) | |
640 { | |
641 eT* out_coldata = out.colptr(col); | |
642 const eT* B_coldata = B.colptr(col); | |
643 | |
644 for(uword row=0; row<B.n_rows; ++row) | |
645 { | |
646 out_coldata[row] = A[row] * B_coldata[row]; | |
647 } | |
648 } | |
649 } | |
650 else | |
651 if( (S1.do_diagmat == false) && (S2.do_diagmat == true) ) | |
652 { | |
653 const unwrap_check<T1> tmp(X.A, out); | |
654 const Mat<eT>& A = tmp.M; | |
655 | |
656 const diagmat_proxy_check<T2_stripped> B(S2.M, out); | |
657 | |
658 arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_elem, B.n_elem, "matrix multiplication"); | |
659 | |
660 out.set_size(A.n_rows, B.n_elem); | |
661 | |
662 for(uword col=0; col<A.n_cols; ++col) | |
663 { | |
664 const eT val = B[col]; | |
665 | |
666 eT* out_coldata = out.colptr(col); | |
667 const eT* A_coldata = A.colptr(col); | |
668 | |
669 for(uword row=0; row<A.n_rows; ++row) | |
670 { | |
671 out_coldata[row] = A_coldata[row] * val; | |
672 } | |
673 } | |
674 } | |
675 else | |
676 if( (S1.do_diagmat == true) && (S2.do_diagmat == true) ) | |
677 { | |
678 const diagmat_proxy_check<T1_stripped> A(S1.M, out); | |
679 const diagmat_proxy_check<T2_stripped> B(S2.M, out); | |
680 | |
681 arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_elem, B.n_elem, "matrix multiplication"); | |
682 | |
683 out.zeros(A.n_elem, A.n_elem); | |
684 | |
685 for(uword i=0; i<A.n_elem; ++i) | |
686 { | |
687 out.at(i,i) = A[i] * B[i]; | |
688 } | |
689 } | |
690 } | |
691 | |
692 | |
693 | |
694 //! @} |