annotate armadillo-3.900.4/include/armadillo_bits/glue_times_meat.hpp @ 84:55a047986812 tip

Update library URI so as not to be document-local
author Chris Cannam
date Wed, 22 Apr 2020 14:21:57 +0100
parents 1ec0e2823891
children
rev   line source
Chris@49 1 // Copyright (C) 2008-2013 NICTA (www.nicta.com.au)
Chris@49 2 // Copyright (C) 2008-2013 Conrad Sanderson
Chris@49 3 //
Chris@49 4 // This Source Code Form is subject to the terms of the Mozilla Public
Chris@49 5 // License, v. 2.0. If a copy of the MPL was not distributed with this
Chris@49 6 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
Chris@49 7
Chris@49 8
Chris@49 9 //! \addtogroup glue_times
Chris@49 10 //! @{
Chris@49 11
Chris@49 12
Chris@49 13
Chris@49 14 template<bool is_eT_blas_type>
Chris@49 15 template<typename T1, typename T2>
Chris@49 16 arma_hot
Chris@49 17 inline
Chris@49 18 void
Chris@49 19 glue_times_redirect2_helper<is_eT_blas_type>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
Chris@49 20 {
Chris@49 21 arma_extra_debug_sigprint();
Chris@49 22
Chris@49 23 typedef typename T1::elem_type eT;
Chris@49 24
Chris@49 25 const partial_unwrap_check<T1> tmp1(X.A, out);
Chris@49 26 const partial_unwrap_check<T2> tmp2(X.B, out);
Chris@49 27
Chris@49 28 const typename partial_unwrap_check<T1>::stored_type& A = tmp1.M;
Chris@49 29 const typename partial_unwrap_check<T2>::stored_type& B = tmp2.M;
Chris@49 30
Chris@49 31 const bool use_alpha = partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times;
Chris@49 32 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0);
Chris@49 33
Chris@49 34 glue_times::apply
Chris@49 35 <
Chris@49 36 eT,
Chris@49 37 partial_unwrap_check<T1>::do_trans,
Chris@49 38 partial_unwrap_check<T2>::do_trans,
Chris@49 39 (partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times)
Chris@49 40 >
Chris@49 41 (out, A, B, alpha);
Chris@49 42 }
Chris@49 43
Chris@49 44
Chris@49 45
Chris@49 46 template<typename T1, typename T2>
Chris@49 47 arma_hot
Chris@49 48 inline
Chris@49 49 void
Chris@49 50 glue_times_redirect2_helper<true>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
Chris@49 51 {
Chris@49 52 arma_extra_debug_sigprint();
Chris@49 53
Chris@49 54 typedef typename T1::elem_type eT;
Chris@49 55
Chris@49 56 if(strip_inv<T1>::do_inv == false)
Chris@49 57 {
Chris@49 58 const partial_unwrap_check<T1> tmp1(X.A, out);
Chris@49 59 const partial_unwrap_check<T2> tmp2(X.B, out);
Chris@49 60
Chris@49 61 const typename partial_unwrap_check<T1>::stored_type& A = tmp1.M;
Chris@49 62 const typename partial_unwrap_check<T2>::stored_type& B = tmp2.M;
Chris@49 63
Chris@49 64 const bool use_alpha = partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times;
Chris@49 65 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0);
Chris@49 66
Chris@49 67 glue_times::apply
Chris@49 68 <
Chris@49 69 eT,
Chris@49 70 partial_unwrap_check<T1>::do_trans,
Chris@49 71 partial_unwrap_check<T2>::do_trans,
Chris@49 72 (partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times)
Chris@49 73 >
Chris@49 74 (out, A, B, alpha);
Chris@49 75 }
Chris@49 76 else
Chris@49 77 {
Chris@49 78 arma_extra_debug_print("glue_times_redirect<2>::apply(): detected inv(A)*B");
Chris@49 79
Chris@49 80 const strip_inv<T1> A_strip(X.A);
Chris@49 81
Chris@49 82 Mat<eT> A = A_strip.M;
Chris@49 83
Chris@49 84 arma_debug_check( (A.is_square() == false), "inv(): given matrix is not square" );
Chris@49 85
Chris@49 86 const unwrap_check<T2> B_tmp(X.B, out);
Chris@49 87 const Mat<eT>& B = B_tmp.M;
Chris@49 88
Chris@49 89 glue_solve::solve_direct( out, A, B, A_strip.slow );
Chris@49 90 }
Chris@49 91 }
Chris@49 92
Chris@49 93
Chris@49 94
Chris@49 95 template<uword N>
Chris@49 96 template<typename T1, typename T2>
Chris@49 97 arma_hot
Chris@49 98 inline
Chris@49 99 void
Chris@49 100 glue_times_redirect<N>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
Chris@49 101 {
Chris@49 102 arma_extra_debug_sigprint();
Chris@49 103
Chris@49 104 typedef typename T1::elem_type eT;
Chris@49 105
Chris@49 106 const partial_unwrap_check<T1> tmp1(X.A, out);
Chris@49 107 const partial_unwrap_check<T2> tmp2(X.B, out);
Chris@49 108
Chris@49 109 const typename partial_unwrap_check<T1>::stored_type& A = tmp1.M;
Chris@49 110 const typename partial_unwrap_check<T2>::stored_type& B = tmp2.M;
Chris@49 111
Chris@49 112 const bool use_alpha = partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times;
Chris@49 113 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0);
Chris@49 114
Chris@49 115 glue_times::apply
Chris@49 116 <
Chris@49 117 eT,
Chris@49 118 partial_unwrap_check<T1>::do_trans,
Chris@49 119 partial_unwrap_check<T2>::do_trans,
Chris@49 120 (partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times)
Chris@49 121 >
Chris@49 122 (out, A, B, alpha);
Chris@49 123 }
Chris@49 124
Chris@49 125
Chris@49 126
Chris@49 127 template<typename T1, typename T2>
Chris@49 128 arma_hot
Chris@49 129 inline
Chris@49 130 void
Chris@49 131 glue_times_redirect<2>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
Chris@49 132 {
Chris@49 133 arma_extra_debug_sigprint();
Chris@49 134
Chris@49 135 typedef typename T1::elem_type eT;
Chris@49 136
Chris@49 137 glue_times_redirect2_helper< is_supported_blas_type<eT>::value >::apply(out, X);
Chris@49 138 }
Chris@49 139
Chris@49 140
Chris@49 141
Chris@49 142 template<typename T1, typename T2, typename T3>
Chris@49 143 arma_hot
Chris@49 144 inline
Chris@49 145 void
Chris@49 146 glue_times_redirect<3>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue<T1,T2,glue_times>, T3, glue_times>& X)
Chris@49 147 {
Chris@49 148 arma_extra_debug_sigprint();
Chris@49 149
Chris@49 150 typedef typename T1::elem_type eT;
Chris@49 151
Chris@49 152 // TODO: investigate detecting inv(A)*B*C and replacing with solve(A,B)*C
Chris@49 153 // TODO: investigate detecting A*inv(B)*C and replacing with A*solve(B,C)
Chris@49 154
Chris@49 155 // there is exactly 3 objects
Chris@49 156 // hence we can safely expand X as X.A.A, X.A.B and X.B
Chris@49 157
Chris@49 158 const partial_unwrap_check<T1> tmp1(X.A.A, out);
Chris@49 159 const partial_unwrap_check<T2> tmp2(X.A.B, out);
Chris@49 160 const partial_unwrap_check<T3> tmp3(X.B, out);
Chris@49 161
Chris@49 162 const typename partial_unwrap_check<T1>::stored_type& A = tmp1.M;
Chris@49 163 const typename partial_unwrap_check<T2>::stored_type& B = tmp2.M;
Chris@49 164 const typename partial_unwrap_check<T3>::stored_type& C = tmp3.M;
Chris@49 165
Chris@49 166 const bool use_alpha = partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times || partial_unwrap_check<T3>::do_times;
Chris@49 167 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val()) : eT(0);
Chris@49 168
Chris@49 169 glue_times::apply
Chris@49 170 <
Chris@49 171 eT,
Chris@49 172 partial_unwrap_check<T1>::do_trans,
Chris@49 173 partial_unwrap_check<T2>::do_trans,
Chris@49 174 partial_unwrap_check<T3>::do_trans,
Chris@49 175 (partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times || partial_unwrap_check<T3>::do_times)
Chris@49 176 >
Chris@49 177 (out, A, B, C, alpha);
Chris@49 178 }
Chris@49 179
Chris@49 180
Chris@49 181
Chris@49 182 template<typename T1, typename T2, typename T3, typename T4>
Chris@49 183 arma_hot
Chris@49 184 inline
Chris@49 185 void
Chris@49 186 glue_times_redirect<4>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue< Glue<T1,T2,glue_times>, T3, glue_times>, T4, glue_times>& X)
Chris@49 187 {
Chris@49 188 arma_extra_debug_sigprint();
Chris@49 189
Chris@49 190 typedef typename T1::elem_type eT;
Chris@49 191
Chris@49 192 // there is exactly 4 objects
Chris@49 193 // hence we can safely expand X as X.A.A.A, X.A.A.B, X.A.B and X.B
Chris@49 194
Chris@49 195 const partial_unwrap_check<T1> tmp1(X.A.A.A, out);
Chris@49 196 const partial_unwrap_check<T2> tmp2(X.A.A.B, out);
Chris@49 197 const partial_unwrap_check<T3> tmp3(X.A.B, out);
Chris@49 198 const partial_unwrap_check<T4> tmp4(X.B, out);
Chris@49 199
Chris@49 200 const typename partial_unwrap_check<T1>::stored_type& A = tmp1.M;
Chris@49 201 const typename partial_unwrap_check<T2>::stored_type& B = tmp2.M;
Chris@49 202 const typename partial_unwrap_check<T3>::stored_type& C = tmp3.M;
Chris@49 203 const typename partial_unwrap_check<T4>::stored_type& D = tmp4.M;
Chris@49 204
Chris@49 205 const bool use_alpha = partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times || partial_unwrap_check<T3>::do_times || partial_unwrap_check<T4>::do_times;
Chris@49 206 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val() * tmp4.get_val()) : eT(0);
Chris@49 207
Chris@49 208 glue_times::apply
Chris@49 209 <
Chris@49 210 eT,
Chris@49 211 partial_unwrap_check<T1>::do_trans,
Chris@49 212 partial_unwrap_check<T2>::do_trans,
Chris@49 213 partial_unwrap_check<T3>::do_trans,
Chris@49 214 partial_unwrap_check<T4>::do_trans,
Chris@49 215 (partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times || partial_unwrap_check<T3>::do_times || partial_unwrap_check<T4>::do_times)
Chris@49 216 >
Chris@49 217 (out, A, B, C, D, alpha);
Chris@49 218 }
Chris@49 219
Chris@49 220
Chris@49 221
Chris@49 222 template<typename T1, typename T2>
Chris@49 223 arma_hot
Chris@49 224 inline
Chris@49 225 void
Chris@49 226 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
Chris@49 227 {
Chris@49 228 arma_extra_debug_sigprint();
Chris@49 229
Chris@49 230 const sword N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num;
Chris@49 231
Chris@49 232 arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat);
Chris@49 233
Chris@49 234 glue_times_redirect<N_mat>::apply(out, X);
Chris@49 235 }
Chris@49 236
Chris@49 237
Chris@49 238
Chris@49 239 template<typename T1>
Chris@49 240 arma_hot
Chris@49 241 inline
Chris@49 242 void
Chris@49 243 glue_times::apply_inplace(Mat<typename T1::elem_type>& out, const T1& X)
Chris@49 244 {
Chris@49 245 arma_extra_debug_sigprint();
Chris@49 246
Chris@49 247 typedef typename T1::elem_type eT;
Chris@49 248
Chris@49 249 const unwrap_check<T1> B_tmp(X, out);
Chris@49 250 const Mat<eT>& B = B_tmp.M;
Chris@49 251
Chris@49 252 arma_debug_assert_mul_size(out, B, "matrix multiplication");
Chris@49 253
Chris@49 254 const uword out_n_rows = out.n_rows;
Chris@49 255 const uword out_n_cols = out.n_cols;
Chris@49 256
Chris@49 257 if(out_n_cols == B.n_cols)
Chris@49 258 {
Chris@49 259 // size of resulting matrix is the same as 'out'
Chris@49 260
Chris@49 261 podarray<eT> tmp(out_n_cols);
Chris@49 262
Chris@49 263 eT* tmp_rowdata = tmp.memptr();
Chris@49 264
Chris@49 265 for(uword row=0; row < out_n_rows; ++row)
Chris@49 266 {
Chris@49 267 tmp.copy_row(out, row);
Chris@49 268
Chris@49 269 for(uword col=0; col < out_n_cols; ++col)
Chris@49 270 {
Chris@49 271 out.at(row,col) = op_dot::direct_dot( out_n_cols, tmp_rowdata, B.colptr(col) );
Chris@49 272 }
Chris@49 273 }
Chris@49 274
Chris@49 275 }
Chris@49 276 else
Chris@49 277 {
Chris@49 278 const Mat<eT> tmp(out);
Chris@49 279
Chris@49 280 glue_times::apply<eT, false, false, false>(out, tmp, B, eT(1));
Chris@49 281 }
Chris@49 282
Chris@49 283 }
Chris@49 284
Chris@49 285
Chris@49 286
Chris@49 287 template<typename T1, typename T2>
Chris@49 288 arma_hot
Chris@49 289 inline
Chris@49 290 void
Chris@49 291 glue_times::apply_inplace_plus(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times>& X, const sword sign)
Chris@49 292 {
Chris@49 293 arma_extra_debug_sigprint();
Chris@49 294
Chris@49 295 typedef typename T1::elem_type eT;
Chris@49 296
Chris@49 297 const partial_unwrap_check<T1> tmp1(X.A, out);
Chris@49 298 const partial_unwrap_check<T2> tmp2(X.B, out);
Chris@49 299
Chris@49 300 typedef typename partial_unwrap_check<T1>::stored_type TA;
Chris@49 301 typedef typename partial_unwrap_check<T2>::stored_type TB;
Chris@49 302
Chris@49 303 const TA& A = tmp1.M;
Chris@49 304 const TB& B = tmp2.M;
Chris@49 305
Chris@49 306 const bool do_trans_A = partial_unwrap_check<T1>::do_trans;
Chris@49 307 const bool do_trans_B = partial_unwrap_check<T2>::do_trans;
Chris@49 308
Chris@49 309 const bool use_alpha = partial_unwrap_check<T1>::do_times || partial_unwrap_check<T2>::do_times || (sign < sword(0));
Chris@49 310 const eT alpha = use_alpha ? ( tmp1.get_val() * tmp2.get_val() * ( (sign > sword(0)) ? eT(1) : eT(-1) ) ) : eT(0);
Chris@49 311
Chris@49 312 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication");
Chris@49 313
Chris@49 314 const uword result_n_rows = (do_trans_A == false) ? (TA::is_row ? 1 : A.n_rows) : (TA::is_col ? 1 : A.n_cols);
Chris@49 315 const uword result_n_cols = (do_trans_B == false) ? (TB::is_col ? 1 : B.n_cols) : (TB::is_row ? 1 : B.n_rows);
Chris@49 316
Chris@49 317 arma_debug_assert_same_size(out.n_rows, out.n_cols, result_n_rows, result_n_cols, "addition");
Chris@49 318
Chris@49 319 if(out.n_elem > 0)
Chris@49 320 {
Chris@49 321 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
Chris@49 322 {
Chris@49 323 if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex<eT>::value == false) )
Chris@49 324 {
Chris@49 325 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
Chris@49 326 }
Chris@49 327 else
Chris@49 328 if( (B.n_cols == 1) || (TB::is_col) )
Chris@49 329 {
Chris@49 330 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
Chris@49 331 }
Chris@49 332 else
Chris@49 333 {
Chris@49 334 gemm<false, false, false, true>::apply(out, A, B, alpha, eT(1));
Chris@49 335 }
Chris@49 336 }
Chris@49 337 else
Chris@49 338 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
Chris@49 339 {
Chris@49 340 if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex<eT>::value == false) )
Chris@49 341 {
Chris@49 342 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
Chris@49 343 }
Chris@49 344 else
Chris@49 345 if( (B.n_cols == 1) || (TB::is_col) )
Chris@49 346 {
Chris@49 347 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
Chris@49 348 }
Chris@49 349 else
Chris@49 350 {
Chris@49 351 gemm<false, false, true, true>::apply(out, A, B, alpha, eT(1));
Chris@49 352 }
Chris@49 353 }
Chris@49 354 else
Chris@49 355 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
Chris@49 356 {
Chris@49 357 if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex<eT>::value == false) )
Chris@49 358 {
Chris@49 359 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
Chris@49 360 }
Chris@49 361 else
Chris@49 362 if( (B.n_cols == 1) || (TB::is_col) )
Chris@49 363 {
Chris@49 364 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
Chris@49 365 }
Chris@49 366 else
Chris@49 367 {
Chris@49 368 gemm<true, false, false, true>::apply(out, A, B, alpha, eT(1));
Chris@49 369 }
Chris@49 370 }
Chris@49 371 else
Chris@49 372 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
Chris@49 373 {
Chris@49 374 if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex<eT>::value == false) )
Chris@49 375 {
Chris@49 376 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
Chris@49 377 }
Chris@49 378 else
Chris@49 379 if( (B.n_cols == 1) || (TB::is_col) )
Chris@49 380 {
Chris@49 381 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
Chris@49 382 }
Chris@49 383 else
Chris@49 384 {
Chris@49 385 gemm<true, false, true, true>::apply(out, A, B, alpha, eT(1));
Chris@49 386 }
Chris@49 387 }
Chris@49 388 else
Chris@49 389 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
Chris@49 390 {
Chris@49 391 if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex<eT>::value == false) )
Chris@49 392 {
Chris@49 393 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
Chris@49 394 }
Chris@49 395 else
Chris@49 396 if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex<eT>::value == false) )
Chris@49 397 {
Chris@49 398 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
Chris@49 399 }
Chris@49 400 else
Chris@49 401 {
Chris@49 402 gemm<false, true, false, true>::apply(out, A, B, alpha, eT(1));
Chris@49 403 }
Chris@49 404 }
Chris@49 405 else
Chris@49 406 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
Chris@49 407 {
Chris@49 408 if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex<eT>::value == false) )
Chris@49 409 {
Chris@49 410 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
Chris@49 411 }
Chris@49 412 else
Chris@49 413 if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex<eT>::value == false) )
Chris@49 414 {
Chris@49 415 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
Chris@49 416 }
Chris@49 417 else
Chris@49 418 {
Chris@49 419 gemm<false, true, true, true>::apply(out, A, B, alpha, eT(1));
Chris@49 420 }
Chris@49 421 }
Chris@49 422 else
Chris@49 423 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
Chris@49 424 {
Chris@49 425 if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex<eT>::value == false) )
Chris@49 426 {
Chris@49 427 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
Chris@49 428 }
Chris@49 429 else
Chris@49 430 if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex<eT>::value == false) )
Chris@49 431 {
Chris@49 432 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
Chris@49 433 }
Chris@49 434 else
Chris@49 435 {
Chris@49 436 gemm<true, true, false, true>::apply(out, A, B, alpha, eT(1));
Chris@49 437 }
Chris@49 438 }
Chris@49 439 else
Chris@49 440 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
Chris@49 441 {
Chris@49 442 if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex<eT>::value == false) )
Chris@49 443 {
Chris@49 444 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
Chris@49 445 }
Chris@49 446 else
Chris@49 447 if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex<eT>::value == false) )
Chris@49 448 {
Chris@49 449 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
Chris@49 450 }
Chris@49 451 else
Chris@49 452 {
Chris@49 453 gemm<true, true, true, true>::apply(out, A, B, alpha, eT(1));
Chris@49 454 }
Chris@49 455 }
Chris@49 456 }
Chris@49 457
Chris@49 458
Chris@49 459 }
Chris@49 460
Chris@49 461
Chris@49 462
Chris@49 463 template<typename eT, const bool do_trans_A, const bool do_trans_B, typename TA, typename TB>
Chris@49 464 arma_inline
Chris@49 465 uword
Chris@49 466 glue_times::mul_storage_cost(const TA& A, const TB& B)
Chris@49 467 {
Chris@49 468 const uword final_A_n_rows = (do_trans_A == false) ? ( TA::is_row ? 1 : A.n_rows ) : ( TA::is_col ? 1 : A.n_cols );
Chris@49 469 const uword final_B_n_cols = (do_trans_B == false) ? ( TB::is_col ? 1 : B.n_cols ) : ( TB::is_row ? 1 : B.n_rows );
Chris@49 470
Chris@49 471 return final_A_n_rows * final_B_n_cols;
Chris@49 472 }
Chris@49 473
Chris@49 474
Chris@49 475
Chris@49 476 template
Chris@49 477 <
Chris@49 478 typename eT,
Chris@49 479 const bool do_trans_A,
Chris@49 480 const bool do_trans_B,
Chris@49 481 const bool use_alpha,
Chris@49 482 typename TA,
Chris@49 483 typename TB
Chris@49 484 >
Chris@49 485 arma_hot
Chris@49 486 inline
Chris@49 487 void
Chris@49 488 glue_times::apply
Chris@49 489 (
Chris@49 490 Mat<eT>& out,
Chris@49 491 const TA& A,
Chris@49 492 const TB& B,
Chris@49 493 const eT alpha
Chris@49 494 )
Chris@49 495 {
Chris@49 496 arma_extra_debug_sigprint();
Chris@49 497
Chris@49 498 //arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication");
Chris@49 499 arma_debug_assert_trans_mul_size<do_trans_A, do_trans_B>(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication");
Chris@49 500
Chris@49 501 const uword final_n_rows = (do_trans_A == false) ? (TA::is_row ? 1 : A.n_rows) : (TA::is_col ? 1 : A.n_cols);
Chris@49 502 const uword final_n_cols = (do_trans_B == false) ? (TB::is_col ? 1 : B.n_cols) : (TB::is_row ? 1 : B.n_rows);
Chris@49 503
Chris@49 504 out.set_size(final_n_rows, final_n_cols);
Chris@49 505
Chris@49 506 if( (A.n_elem > 0) && (B.n_elem > 0) )
Chris@49 507 {
Chris@49 508 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
Chris@49 509 {
Chris@49 510 if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex<eT>::value == false) )
Chris@49 511 {
Chris@49 512 gemv<true, false, false>::apply(out.memptr(), B, A.memptr());
Chris@49 513 }
Chris@49 514 else
Chris@49 515 if( (B.n_cols == 1) || (TB::is_col) )
Chris@49 516 {
Chris@49 517 gemv<false, false, false>::apply(out.memptr(), A, B.memptr());
Chris@49 518 }
Chris@49 519 else
Chris@49 520 {
Chris@49 521 gemm<false, false, false, false>::apply(out, A, B);
Chris@49 522 }
Chris@49 523 }
Chris@49 524 else
Chris@49 525 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
Chris@49 526 {
Chris@49 527 if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex<eT>::value == false) )
Chris@49 528 {
Chris@49 529 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
Chris@49 530 }
Chris@49 531 else
Chris@49 532 if( (B.n_cols == 1) || (TB::is_col) )
Chris@49 533 {
Chris@49 534 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
Chris@49 535 }
Chris@49 536 else
Chris@49 537 {
Chris@49 538 gemm<false, false, true, false>::apply(out, A, B, alpha);
Chris@49 539 }
Chris@49 540 }
Chris@49 541 else
Chris@49 542 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
Chris@49 543 {
Chris@49 544 if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex<eT>::value == false) )
Chris@49 545 {
Chris@49 546 gemv<true, false, false>::apply(out.memptr(), B, A.memptr());
Chris@49 547 }
Chris@49 548 else
Chris@49 549 if( (B.n_cols == 1) || (TB::is_col) )
Chris@49 550 {
Chris@49 551 gemv<true, false, false>::apply(out.memptr(), A, B.memptr());
Chris@49 552 }
Chris@49 553 else
Chris@49 554 {
Chris@49 555 gemm<true, false, false, false>::apply(out, A, B);
Chris@49 556 }
Chris@49 557 }
Chris@49 558 else
Chris@49 559 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
Chris@49 560 {
Chris@49 561 if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex<eT>::value == false) )
Chris@49 562 {
Chris@49 563 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
Chris@49 564 }
Chris@49 565 else
Chris@49 566 if( (B.n_cols == 1) || (TB::is_col) )
Chris@49 567 {
Chris@49 568 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
Chris@49 569 }
Chris@49 570 else
Chris@49 571 {
Chris@49 572 gemm<true, false, true, false>::apply(out, A, B, alpha);
Chris@49 573 }
Chris@49 574 }
Chris@49 575 else
Chris@49 576 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
Chris@49 577 {
Chris@49 578 if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex<eT>::value == false) )
Chris@49 579 {
Chris@49 580 gemv<false, false, false>::apply(out.memptr(), B, A.memptr());
Chris@49 581 }
Chris@49 582 else
Chris@49 583 if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex<eT>::value == false) )
Chris@49 584 {
Chris@49 585 gemv<false, false, false>::apply(out.memptr(), A, B.memptr());
Chris@49 586 }
Chris@49 587 else
Chris@49 588 {
Chris@49 589 gemm<false, true, false, false>::apply(out, A, B);
Chris@49 590 }
Chris@49 591 }
Chris@49 592 else
Chris@49 593 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
Chris@49 594 {
Chris@49 595 if( ((A.n_rows == 1) || (TA::is_row)) && (is_complex<eT>::value == false) )
Chris@49 596 {
Chris@49 597 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
Chris@49 598 }
Chris@49 599 else
Chris@49 600 if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex<eT>::value == false) )
Chris@49 601 {
Chris@49 602 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
Chris@49 603 }
Chris@49 604 else
Chris@49 605 {
Chris@49 606 gemm<false, true, true, false>::apply(out, A, B, alpha);
Chris@49 607 }
Chris@49 608 }
Chris@49 609 else
Chris@49 610 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
Chris@49 611 {
Chris@49 612 if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex<eT>::value == false) )
Chris@49 613 {
Chris@49 614 gemv<false, false, false>::apply(out.memptr(), B, A.memptr());
Chris@49 615 }
Chris@49 616 else
Chris@49 617 if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex<eT>::value == false) )
Chris@49 618 {
Chris@49 619 gemv<true, false, false>::apply(out.memptr(), A, B.memptr());
Chris@49 620 }
Chris@49 621 else
Chris@49 622 {
Chris@49 623 gemm<true, true, false, false>::apply(out, A, B);
Chris@49 624 }
Chris@49 625 }
Chris@49 626 else
Chris@49 627 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
Chris@49 628 {
Chris@49 629 if( ((A.n_cols == 1) || (TA::is_col)) && (is_complex<eT>::value == false) )
Chris@49 630 {
Chris@49 631 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
Chris@49 632 }
Chris@49 633 else
Chris@49 634 if( ((B.n_rows == 1) || (TB::is_row)) && (is_complex<eT>::value == false) )
Chris@49 635 {
Chris@49 636 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
Chris@49 637 }
Chris@49 638 else
Chris@49 639 {
Chris@49 640 gemm<true, true, true, false>::apply(out, A, B, alpha);
Chris@49 641 }
Chris@49 642 }
Chris@49 643 }
Chris@49 644 else
Chris@49 645 {
Chris@49 646 out.zeros();
Chris@49 647 }
Chris@49 648 }
Chris@49 649
Chris@49 650
Chris@49 651
Chris@49 652 template
Chris@49 653 <
Chris@49 654 typename eT,
Chris@49 655 const bool do_trans_A,
Chris@49 656 const bool do_trans_B,
Chris@49 657 const bool do_trans_C,
Chris@49 658 const bool use_alpha,
Chris@49 659 typename TA,
Chris@49 660 typename TB,
Chris@49 661 typename TC
Chris@49 662 >
Chris@49 663 arma_hot
Chris@49 664 inline
Chris@49 665 void
Chris@49 666 glue_times::apply
Chris@49 667 (
Chris@49 668 Mat<eT>& out,
Chris@49 669 const TA& A,
Chris@49 670 const TB& B,
Chris@49 671 const TC& C,
Chris@49 672 const eT alpha
Chris@49 673 )
Chris@49 674 {
Chris@49 675 arma_extra_debug_sigprint();
Chris@49 676
Chris@49 677 Mat<eT> tmp;
Chris@49 678
Chris@49 679 const uword storage_cost_AB = glue_times::mul_storage_cost<eT, do_trans_A, do_trans_B>(A, B);
Chris@49 680 const uword storage_cost_BC = glue_times::mul_storage_cost<eT, do_trans_B, do_trans_C>(B, C);
Chris@49 681
Chris@49 682 if(storage_cost_AB <= storage_cost_BC)
Chris@49 683 {
Chris@49 684 // out = (A*B)*C
Chris@49 685
Chris@49 686 glue_times::apply<eT, do_trans_A, do_trans_B, use_alpha>(tmp, A, B, alpha);
Chris@49 687 glue_times::apply<eT, false, do_trans_C, false >(out, tmp, C, eT(0));
Chris@49 688 }
Chris@49 689 else
Chris@49 690 {
Chris@49 691 // out = A*(B*C)
Chris@49 692
Chris@49 693 glue_times::apply<eT, do_trans_B, do_trans_C, use_alpha>(tmp, B, C, alpha);
Chris@49 694 glue_times::apply<eT, do_trans_A, false, false >(out, A, tmp, eT(0));
Chris@49 695 }
Chris@49 696 }
Chris@49 697
Chris@49 698
Chris@49 699
Chris@49 700 template
Chris@49 701 <
Chris@49 702 typename eT,
Chris@49 703 const bool do_trans_A,
Chris@49 704 const bool do_trans_B,
Chris@49 705 const bool do_trans_C,
Chris@49 706 const bool do_trans_D,
Chris@49 707 const bool use_alpha,
Chris@49 708 typename TA,
Chris@49 709 typename TB,
Chris@49 710 typename TC,
Chris@49 711 typename TD
Chris@49 712 >
Chris@49 713 arma_hot
Chris@49 714 inline
Chris@49 715 void
Chris@49 716 glue_times::apply
Chris@49 717 (
Chris@49 718 Mat<eT>& out,
Chris@49 719 const TA& A,
Chris@49 720 const TB& B,
Chris@49 721 const TC& C,
Chris@49 722 const TD& D,
Chris@49 723 const eT alpha
Chris@49 724 )
Chris@49 725 {
Chris@49 726 arma_extra_debug_sigprint();
Chris@49 727
Chris@49 728 Mat<eT> tmp;
Chris@49 729
Chris@49 730 const uword storage_cost_AC = glue_times::mul_storage_cost<eT, do_trans_A, do_trans_C>(A, C);
Chris@49 731 const uword storage_cost_BD = glue_times::mul_storage_cost<eT, do_trans_B, do_trans_D>(B, D);
Chris@49 732
Chris@49 733 if(storage_cost_AC <= storage_cost_BD)
Chris@49 734 {
Chris@49 735 // out = (A*B*C)*D
Chris@49 736
Chris@49 737 glue_times::apply<eT, do_trans_A, do_trans_B, do_trans_C, use_alpha>(tmp, A, B, C, alpha);
Chris@49 738
Chris@49 739 glue_times::apply<eT, false, do_trans_D, false>(out, tmp, D, eT(0));
Chris@49 740 }
Chris@49 741 else
Chris@49 742 {
Chris@49 743 // out = A*(B*C*D)
Chris@49 744
Chris@49 745 glue_times::apply<eT, do_trans_B, do_trans_C, do_trans_D, use_alpha>(tmp, B, C, D, alpha);
Chris@49 746
Chris@49 747 glue_times::apply<eT, do_trans_A, false, false>(out, A, tmp, eT(0));
Chris@49 748 }
Chris@49 749 }
Chris@49 750
Chris@49 751
Chris@49 752
Chris@49 753 //
Chris@49 754 // glue_times_diag
Chris@49 755
Chris@49 756
Chris@49 757 template<typename T1, typename T2>
Chris@49 758 arma_hot
Chris@49 759 inline
Chris@49 760 void
Chris@49 761 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times_diag>& X)
Chris@49 762 {
Chris@49 763 arma_extra_debug_sigprint();
Chris@49 764
Chris@49 765 typedef typename T1::elem_type eT;
Chris@49 766
Chris@49 767 const strip_diagmat<T1> S1(X.A);
Chris@49 768 const strip_diagmat<T2> S2(X.B);
Chris@49 769
Chris@49 770 typedef typename strip_diagmat<T1>::stored_type T1_stripped;
Chris@49 771 typedef typename strip_diagmat<T2>::stored_type T2_stripped;
Chris@49 772
Chris@49 773 if( (strip_diagmat<T1>::do_diagmat == true) && (strip_diagmat<T2>::do_diagmat == false) )
Chris@49 774 {
Chris@49 775 const diagmat_proxy_check<T1_stripped> A(S1.M, out);
Chris@49 776
Chris@49 777 const unwrap_check<T2> tmp(X.B, out);
Chris@49 778 const Mat<eT>& B = tmp.M;
Chris@49 779
Chris@49 780 const uword A_n_elem = A.n_elem;
Chris@49 781 const uword B_n_rows = B.n_rows;
Chris@49 782 const uword B_n_cols = B.n_cols;
Chris@49 783
Chris@49 784 arma_debug_assert_mul_size(A_n_elem, A_n_elem, B_n_rows, B_n_cols, "matrix multiplication");
Chris@49 785
Chris@49 786 out.set_size(A_n_elem, B_n_cols);
Chris@49 787
Chris@49 788 for(uword col=0; col < B_n_cols; ++col)
Chris@49 789 {
Chris@49 790 eT* out_coldata = out.colptr(col);
Chris@49 791 const eT* B_coldata = B.colptr(col);
Chris@49 792
Chris@49 793 uword i,j;
Chris@49 794 for(i=0, j=1; j < B_n_rows; i+=2, j+=2)
Chris@49 795 {
Chris@49 796 eT tmp_i = A[i];
Chris@49 797 eT tmp_j = A[j];
Chris@49 798
Chris@49 799 tmp_i *= B_coldata[i];
Chris@49 800 tmp_j *= B_coldata[j];
Chris@49 801
Chris@49 802 out_coldata[i] = tmp_i;
Chris@49 803 out_coldata[j] = tmp_j;
Chris@49 804 }
Chris@49 805
Chris@49 806 if(i < B_n_rows)
Chris@49 807 {
Chris@49 808 out_coldata[i] = A[i] * B_coldata[i];
Chris@49 809 }
Chris@49 810 }
Chris@49 811 }
Chris@49 812 else
Chris@49 813 if( (strip_diagmat<T1>::do_diagmat == false) && (strip_diagmat<T2>::do_diagmat == true) )
Chris@49 814 {
Chris@49 815 const unwrap_check<T1> tmp(X.A, out);
Chris@49 816 const Mat<eT>& A = tmp.M;
Chris@49 817
Chris@49 818 const diagmat_proxy_check<T2_stripped> B(S2.M, out);
Chris@49 819
Chris@49 820 const uword A_n_rows = A.n_rows;
Chris@49 821 const uword A_n_cols = A.n_cols;
Chris@49 822 const uword B_n_elem = B.n_elem;
Chris@49 823
Chris@49 824 arma_debug_assert_mul_size(A_n_rows, A_n_cols, B_n_elem, B_n_elem, "matrix multiplication");
Chris@49 825
Chris@49 826 out.set_size(A_n_rows, B_n_elem);
Chris@49 827
Chris@49 828 for(uword col=0; col < A_n_cols; ++col)
Chris@49 829 {
Chris@49 830 const eT val = B[col];
Chris@49 831
Chris@49 832 eT* out_coldata = out.colptr(col);
Chris@49 833 const eT* A_coldata = A.colptr(col);
Chris@49 834
Chris@49 835 uword i,j;
Chris@49 836 for(i=0, j=1; j < A_n_rows; i+=2, j+=2)
Chris@49 837 {
Chris@49 838 const eT tmp_i = A_coldata[i] * val;
Chris@49 839 const eT tmp_j = A_coldata[j] * val;
Chris@49 840
Chris@49 841 out_coldata[i] = tmp_i;
Chris@49 842 out_coldata[j] = tmp_j;
Chris@49 843 }
Chris@49 844
Chris@49 845 if(i < A_n_rows)
Chris@49 846 {
Chris@49 847 out_coldata[i] = A_coldata[i] * val;
Chris@49 848 }
Chris@49 849 }
Chris@49 850 }
Chris@49 851 else
Chris@49 852 if( (strip_diagmat<T1>::do_diagmat == true) && (strip_diagmat<T2>::do_diagmat == true) )
Chris@49 853 {
Chris@49 854 const diagmat_proxy_check<T1_stripped> A(S1.M, out);
Chris@49 855 const diagmat_proxy_check<T2_stripped> B(S2.M, out);
Chris@49 856
Chris@49 857 const uword A_n_elem = A.n_elem;
Chris@49 858 const uword B_n_elem = B.n_elem;
Chris@49 859
Chris@49 860 arma_debug_assert_mul_size(A_n_elem, A_n_elem, B_n_elem, B_n_elem, "matrix multiplication");
Chris@49 861
Chris@49 862 out.zeros(A_n_elem, A_n_elem);
Chris@49 863
Chris@49 864 for(uword i=0; i < A_n_elem; ++i)
Chris@49 865 {
Chris@49 866 out.at(i,i) = A[i] * B[i];
Chris@49 867 }
Chris@49 868 }
Chris@49 869 }
Chris@49 870
Chris@49 871
Chris@49 872
Chris@49 873 //! @}