comparison armadillo-2.4.4/include/armadillo_bits/glue_times_meat.hpp @ 0:8b6102e2a9b0

Armadillo Library
author maxzanoni76 <max.zanoni@eecs.qmul.ac.uk>
date Wed, 11 Apr 2012 09:27:06 +0100
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:8b6102e2a9b0
1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
2 // Copyright (C) 2008-2011 Conrad Sanderson
3 //
4 // This file is part of the Armadillo C++ library.
5 // It is provided without any warranty of fitness
6 // for any purpose. You can redistribute this file
7 // and/or modify it under the terms of the GNU
8 // Lesser General Public License (LGPL) as published
9 // by the Free Software Foundation, either version 3
10 // of the License or (at your option) any later version.
11 // (see http://www.opensource.org/licenses for more info)
12
13
14 //! \addtogroup glue_times
15 //! @{
16
17
18
19 template<uword N>
20 template<typename T1, typename T2>
21 inline
22 void
23 glue_times_redirect<N>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
24 {
25 arma_extra_debug_sigprint();
26
27 typedef typename T1::elem_type eT;
28
29 const partial_unwrap_check<T1> tmp1(X.A, out);
30 const partial_unwrap_check<T2> tmp2(X.B, out);
31
32 const Mat<eT>& A = tmp1.M;
33 const Mat<eT>& B = tmp2.M;
34
35 const bool do_trans_A = tmp1.do_trans;
36 const bool do_trans_B = tmp2.do_trans;
37
38 const bool use_alpha = tmp1.do_times || tmp2.do_times;
39 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0);
40
41 glue_times::apply(out, A, B, alpha, do_trans_A, do_trans_B, use_alpha);
42 }
43
44
45
46 template<typename T1, typename T2, typename T3>
47 inline
48 void
49 glue_times_redirect<3>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue<T1,T2,glue_times>, T3, glue_times>& X)
50 {
51 arma_extra_debug_sigprint();
52
53 typedef typename T1::elem_type eT;
54
55 // there is exactly 3 objects
56 // hence we can safely expand X as X.A.A, X.A.B and X.B
57
58 const partial_unwrap_check<T1> tmp1(X.A.A, out);
59 const partial_unwrap_check<T2> tmp2(X.A.B, out);
60 const partial_unwrap_check<T3> tmp3(X.B, out);
61
62 const Mat<eT>& A = tmp1.M;
63 const Mat<eT>& B = tmp2.M;
64 const Mat<eT>& C = tmp3.M;
65
66 const bool do_trans_A = tmp1.do_trans;
67 const bool do_trans_B = tmp2.do_trans;
68 const bool do_trans_C = tmp3.do_trans;
69
70 const bool use_alpha = tmp1.do_times || tmp2.do_times || tmp3.do_times;
71 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val()) : eT(0);
72
73 glue_times::apply(out, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha);
74 }
75
76
77
78 template<typename T1, typename T2, typename T3, typename T4>
79 inline
80 void
81 glue_times_redirect<4>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue< Glue<T1,T2,glue_times>, T3, glue_times>, T4, glue_times>& X)
82 {
83 arma_extra_debug_sigprint();
84
85 typedef typename T1::elem_type eT;
86
87 // there is exactly 4 objects
88 // hence we can safely expand X as X.A.A.A, X.A.A.B, X.A.B and X.B
89
90 const partial_unwrap_check<T1> tmp1(X.A.A.A, out);
91 const partial_unwrap_check<T2> tmp2(X.A.A.B, out);
92 const partial_unwrap_check<T3> tmp3(X.A.B, out);
93 const partial_unwrap_check<T4> tmp4(X.B, out);
94
95 const Mat<eT>& A = tmp1.M;
96 const Mat<eT>& B = tmp2.M;
97 const Mat<eT>& C = tmp3.M;
98 const Mat<eT>& D = tmp4.M;
99
100 const bool do_trans_A = tmp1.do_trans;
101 const bool do_trans_B = tmp2.do_trans;
102 const bool do_trans_C = tmp3.do_trans;
103 const bool do_trans_D = tmp4.do_trans;
104
105 const bool use_alpha = tmp1.do_times || tmp2.do_times || tmp3.do_times || tmp4.do_times;
106 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val() * tmp4.get_val()) : eT(0);
107
108 glue_times::apply(out, A, B, C, D, alpha, do_trans_A, do_trans_B, do_trans_C, do_trans_D, use_alpha);
109 }
110
111
112
113 template<typename T1, typename T2>
114 inline
115 void
116 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
117 {
118 arma_extra_debug_sigprint();
119
120 typedef typename T1::elem_type eT;
121
122 const sword N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num;
123
124 arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat);
125
126 glue_times_redirect<N_mat>::apply(out, X);
127 }
128
129
130
131 template<typename T1>
132 inline
133 void
134 glue_times::apply_inplace(Mat<typename T1::elem_type>& out, const T1& X)
135 {
136 arma_extra_debug_sigprint();
137
138 typedef typename T1::elem_type eT;
139
140 const unwrap_check<T1> tmp(X, out);
141 const Mat<eT>& B = tmp.M;
142
143 arma_debug_assert_mul_size(out, B, "matrix multiplication");
144
145 const uword out_n_rows = out.n_rows;
146 const uword out_n_cols = out.n_cols;
147
148 if(out_n_cols == B.n_cols)
149 {
150 // size of resulting matrix is the same as 'out'
151
152 podarray<eT> tmp(out_n_cols);
153
154 eT* tmp_rowdata = tmp.memptr();
155
156 for(uword row=0; row < out_n_rows; ++row)
157 {
158 tmp.copy_row(out, row);
159
160 for(uword col=0; col < out_n_cols; ++col)
161 {
162 out.at(row,col) = op_dot::direct_dot( out_n_cols, tmp_rowdata, B.colptr(col) );
163 }
164 }
165
166 }
167 else
168 {
169 const Mat<eT> tmp(out);
170 glue_times::apply(out, tmp, B, eT(1), false, false, false);
171 }
172
173 }
174
175
176
177 template<typename T1, typename T2>
178 arma_hot
179 inline
180 void
181 glue_times::apply_inplace_plus(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times>& X, const sword sign)
182 {
183 arma_extra_debug_sigprint();
184
185 typedef typename T1::elem_type eT;
186
187 const partial_unwrap_check<T1> tmp1(X.A, out);
188 const partial_unwrap_check<T2> tmp2(X.B, out);
189
190 const Mat<eT>& A = tmp1.M;
191 const Mat<eT>& B = tmp2.M;
192 const eT alpha = tmp1.get_val() * tmp2.get_val() * ( (sign > sword(0)) ? eT(1) : eT(-1) );
193
194 const bool do_trans_A = tmp1.do_trans;
195 const bool do_trans_B = tmp2.do_trans;
196 const bool use_alpha = tmp1.do_times || tmp2.do_times || (sign < sword(0));
197
198 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication");
199
200 const uword result_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
201 const uword result_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
202
203 arma_assert_same_size(out.n_rows, out.n_cols, result_n_rows, result_n_cols, "addition");
204
205 if(out.n_elem > 0)
206 {
207 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
208 {
209 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
210 {
211 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
212 }
213 else
214 if(B.n_cols == 1)
215 {
216 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
217 }
218 else
219 {
220 gemm<false, false, false, true>::apply(out, A, B, alpha, eT(1));
221 }
222 }
223 else
224 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
225 {
226 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
227 {
228 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
229 }
230 else
231 if(B.n_cols == 1)
232 {
233 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
234 }
235 else
236 {
237 gemm<false, false, true, true>::apply(out, A, B, alpha, eT(1));
238 }
239 }
240 else
241 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
242 {
243 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
244 {
245 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
246 }
247 else
248 if(B.n_cols == 1)
249 {
250 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
251 }
252 else
253 {
254 gemm<true, false, false, true>::apply(out, A, B, alpha, eT(1));
255 }
256 }
257 else
258 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
259 {
260 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
261 {
262 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
263 }
264 else
265 if(B.n_cols == 1)
266 {
267 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
268 }
269 else
270 {
271 gemm<true, false, true, true>::apply(out, A, B, alpha, eT(1));
272 }
273 }
274 else
275 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
276 {
277 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
278 {
279 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
280 }
281 else
282 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
283 {
284 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
285 }
286 else
287 {
288 gemm<false, true, false, true>::apply(out, A, B, alpha, eT(1));
289 }
290 }
291 else
292 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
293 {
294 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
295 {
296 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
297 }
298 else
299 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
300 {
301 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
302 }
303 else
304 {
305 gemm<false, true, true, true>::apply(out, A, B, alpha, eT(1));
306 }
307 }
308 else
309 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
310 {
311 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
312 {
313 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
314 }
315 else
316 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
317 {
318 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
319 }
320 else
321 {
322 gemm<true, true, false, true>::apply(out, A, B, alpha, eT(1));
323 }
324 }
325 else
326 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
327 {
328 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
329 {
330 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
331 }
332 else
333 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
334 {
335 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
336 }
337 else
338 {
339 gemm<true, true, true, true>::apply(out, A, B, alpha, eT(1));
340 }
341 }
342 }
343
344
345 }
346
347
348
349 template<typename eT>
350 arma_inline
351 uword
352 glue_times::mul_storage_cost(const Mat<eT>& A, const Mat<eT>& B, const bool do_trans_A, const bool do_trans_B)
353 {
354 const uword final_A_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
355 const uword final_B_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
356
357 return final_A_n_rows * final_B_n_cols;
358 }
359
360
361
362 template<typename eT>
363 arma_hot
364 inline
365 void
366 glue_times::apply
367 (
368 Mat<eT>& out,
369 const Mat<eT>& A,
370 const Mat<eT>& B,
371 const eT alpha,
372 const bool do_trans_A,
373 const bool do_trans_B,
374 const bool use_alpha
375 )
376 {
377 arma_extra_debug_sigprint();
378
379 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication");
380
381 const uword final_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
382 const uword final_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
383
384 out.set_size(final_n_rows, final_n_cols);
385
386 if( (A.n_elem > 0) && (B.n_elem > 0) )
387 {
388 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
389 {
390 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
391 {
392 gemv<true, false, false>::apply(out.memptr(), B, A.memptr());
393 }
394 else
395 if(B.n_cols == 1)
396 {
397 gemv<false, false, false>::apply(out.memptr(), A, B.memptr());
398 }
399 else
400 {
401 gemm<false, false, false, false>::apply(out, A, B);
402 }
403 }
404 else
405 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
406 {
407 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
408 {
409 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
410 }
411 else
412 if(B.n_cols == 1)
413 {
414 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
415 }
416 else
417 {
418 gemm<false, false, true, false>::apply(out, A, B, alpha);
419 }
420 }
421 else
422 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
423 {
424 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
425 {
426 gemv<true, false, false>::apply(out.memptr(), B, A.memptr());
427 }
428 else
429 if(B.n_cols == 1)
430 {
431 gemv<true, false, false>::apply(out.memptr(), A, B.memptr());
432 }
433 else
434 {
435 gemm<true, false, false, false>::apply(out, A, B);
436 }
437 }
438 else
439 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
440 {
441 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
442 {
443 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
444 }
445 else
446 if(B.n_cols == 1)
447 {
448 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
449 }
450 else
451 {
452 gemm<true, false, true, false>::apply(out, A, B, alpha);
453 }
454 }
455 else
456 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
457 {
458 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
459 {
460 gemv<false, false, false>::apply(out.memptr(), B, A.memptr());
461 }
462 else
463 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
464 {
465 gemv<false, false, false>::apply(out.memptr(), A, B.memptr());
466 }
467 else
468 {
469 gemm<false, true, false, false>::apply(out, A, B);
470 }
471 }
472 else
473 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
474 {
475 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
476 {
477 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
478 }
479 else
480 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
481 {
482 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
483 }
484 else
485 {
486 gemm<false, true, true, false>::apply(out, A, B, alpha);
487 }
488 }
489 else
490 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
491 {
492 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
493 {
494 gemv<false, false, false>::apply(out.memptr(), B, A.memptr());
495 }
496 else
497 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
498 {
499 gemv<true, false, false>::apply(out.memptr(), A, B.memptr());
500 }
501 else
502 {
503 gemm<true, true, false, false>::apply(out, A, B);
504 }
505 }
506 else
507 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
508 {
509 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
510 {
511 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
512 }
513 else
514 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
515 {
516 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
517 }
518 else
519 {
520 gemm<true, true, true, false>::apply(out, A, B, alpha);
521 }
522 }
523 }
524 else
525 {
526 out.zeros();
527 }
528 }
529
530
531
532 template<typename eT>
533 inline
534 void
535 glue_times::apply
536 (
537 Mat<eT>& out,
538 const Mat<eT>& A,
539 const Mat<eT>& B,
540 const Mat<eT>& C,
541 const eT alpha,
542 const bool do_trans_A,
543 const bool do_trans_B,
544 const bool do_trans_C,
545 const bool use_alpha
546 )
547 {
548 arma_extra_debug_sigprint();
549
550 Mat<eT> tmp;
551
552 if( glue_times::mul_storage_cost(A, B, do_trans_A, do_trans_B) <= glue_times::mul_storage_cost(B, C, do_trans_B, do_trans_C) )
553 {
554 // out = (A*B)*C
555 glue_times::apply(tmp, A, B, alpha, do_trans_A, do_trans_B, use_alpha);
556 glue_times::apply(out, tmp, C, eT(0), false, do_trans_C, false );
557 }
558 else
559 {
560 // out = A*(B*C)
561 glue_times::apply(tmp, B, C, alpha, do_trans_B, do_trans_C, use_alpha);
562 glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false );
563 }
564 }
565
566
567
568 template<typename eT>
569 inline
570 void
571 glue_times::apply
572 (
573 Mat<eT>& out,
574 const Mat<eT>& A,
575 const Mat<eT>& B,
576 const Mat<eT>& C,
577 const Mat<eT>& D,
578 const eT alpha,
579 const bool do_trans_A,
580 const bool do_trans_B,
581 const bool do_trans_C,
582 const bool do_trans_D,
583 const bool use_alpha
584 )
585 {
586 arma_extra_debug_sigprint();
587
588 Mat<eT> tmp;
589
590 if( glue_times::mul_storage_cost(A, C, do_trans_A, do_trans_C) <= glue_times::mul_storage_cost(B, D, do_trans_B, do_trans_D) )
591 {
592 // out = (A*B*C)*D
593 glue_times::apply(tmp, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha);
594
595 glue_times::apply(out, tmp, D, eT(0), false, do_trans_D, false);
596 }
597 else
598 {
599 // out = A*(B*C*D)
600 glue_times::apply(tmp, B, C, D, alpha, do_trans_B, do_trans_C, do_trans_D, use_alpha);
601
602 glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false);
603 }
604 }
605
606
607
608 //
609 // glue_times_diag
610
611
612 template<typename T1, typename T2>
613 arma_hot
614 inline
615 void
616 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times_diag>& X)
617 {
618 arma_extra_debug_sigprint();
619
620 typedef typename T1::elem_type eT;
621
622 const strip_diagmat<T1> S1(X.A);
623 const strip_diagmat<T2> S2(X.B);
624
625 typedef typename strip_diagmat<T1>::stored_type T1_stripped;
626 typedef typename strip_diagmat<T2>::stored_type T2_stripped;
627
628 if( (S1.do_diagmat == true) && (S2.do_diagmat == false) )
629 {
630 const diagmat_proxy_check<T1_stripped> A(S1.M, out);
631
632 const unwrap_check<T2> tmp(X.B, out);
633 const Mat<eT>& B = tmp.M;
634
635 arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_rows, B.n_cols, "matrix multiplication");
636
637 out.set_size(A.n_elem, B.n_cols);
638
639 for(uword col=0; col<B.n_cols; ++col)
640 {
641 eT* out_coldata = out.colptr(col);
642 const eT* B_coldata = B.colptr(col);
643
644 for(uword row=0; row<B.n_rows; ++row)
645 {
646 out_coldata[row] = A[row] * B_coldata[row];
647 }
648 }
649 }
650 else
651 if( (S1.do_diagmat == false) && (S2.do_diagmat == true) )
652 {
653 const unwrap_check<T1> tmp(X.A, out);
654 const Mat<eT>& A = tmp.M;
655
656 const diagmat_proxy_check<T2_stripped> B(S2.M, out);
657
658 arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_elem, B.n_elem, "matrix multiplication");
659
660 out.set_size(A.n_rows, B.n_elem);
661
662 for(uword col=0; col<A.n_cols; ++col)
663 {
664 const eT val = B[col];
665
666 eT* out_coldata = out.colptr(col);
667 const eT* A_coldata = A.colptr(col);
668
669 for(uword row=0; row<A.n_rows; ++row)
670 {
671 out_coldata[row] = A_coldata[row] * val;
672 }
673 }
674 }
675 else
676 if( (S1.do_diagmat == true) && (S2.do_diagmat == true) )
677 {
678 const diagmat_proxy_check<T1_stripped> A(S1.M, out);
679 const diagmat_proxy_check<T2_stripped> B(S2.M, out);
680
681 arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_elem, B.n_elem, "matrix multiplication");
682
683 out.zeros(A.n_elem, A.n_elem);
684
685 for(uword i=0; i<A.n_elem; ++i)
686 {
687 out.at(i,i) = A[i] * B[i];
688 }
689 }
690 }
691
692
693
694 //! @}