annotate armadillo-2.4.4/include/armadillo_bits/glue_times_meat.hpp @ 18:8d046a9d36aa slimline

Back out rev 13:ac07c60aa798. Like an idiot, I committed a whole pile of unrelated changes in the guise of a single typo fix. Will re-commit in stages
author Chris Cannam
date Thu, 10 May 2012 10:45:44 +0100
parents 8b6102e2a9b0
children
rev   line source
max@0 1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
max@0 2 // Copyright (C) 2008-2011 Conrad Sanderson
max@0 3 //
max@0 4 // This file is part of the Armadillo C++ library.
max@0 5 // It is provided without any warranty of fitness
max@0 6 // for any purpose. You can redistribute this file
max@0 7 // and/or modify it under the terms of the GNU
max@0 8 // Lesser General Public License (LGPL) as published
max@0 9 // by the Free Software Foundation, either version 3
max@0 10 // of the License or (at your option) any later version.
max@0 11 // (see http://www.opensource.org/licenses for more info)
max@0 12
max@0 13
max@0 14 //! \addtogroup glue_times
max@0 15 //! @{
max@0 16
max@0 17
max@0 18
max@0 19 template<uword N>
max@0 20 template<typename T1, typename T2>
max@0 21 inline
max@0 22 void
max@0 23 glue_times_redirect<N>::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
max@0 24 {
max@0 25 arma_extra_debug_sigprint();
max@0 26
max@0 27 typedef typename T1::elem_type eT;
max@0 28
max@0 29 const partial_unwrap_check<T1> tmp1(X.A, out);
max@0 30 const partial_unwrap_check<T2> tmp2(X.B, out);
max@0 31
max@0 32 const Mat<eT>& A = tmp1.M;
max@0 33 const Mat<eT>& B = tmp2.M;
max@0 34
max@0 35 const bool do_trans_A = tmp1.do_trans;
max@0 36 const bool do_trans_B = tmp2.do_trans;
max@0 37
max@0 38 const bool use_alpha = tmp1.do_times || tmp2.do_times;
max@0 39 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val()) : eT(0);
max@0 40
max@0 41 glue_times::apply(out, A, B, alpha, do_trans_A, do_trans_B, use_alpha);
max@0 42 }
max@0 43
max@0 44
max@0 45
max@0 46 template<typename T1, typename T2, typename T3>
max@0 47 inline
max@0 48 void
max@0 49 glue_times_redirect<3>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue<T1,T2,glue_times>, T3, glue_times>& X)
max@0 50 {
max@0 51 arma_extra_debug_sigprint();
max@0 52
max@0 53 typedef typename T1::elem_type eT;
max@0 54
max@0 55 // there is exactly 3 objects
max@0 56 // hence we can safely expand X as X.A.A, X.A.B and X.B
max@0 57
max@0 58 const partial_unwrap_check<T1> tmp1(X.A.A, out);
max@0 59 const partial_unwrap_check<T2> tmp2(X.A.B, out);
max@0 60 const partial_unwrap_check<T3> tmp3(X.B, out);
max@0 61
max@0 62 const Mat<eT>& A = tmp1.M;
max@0 63 const Mat<eT>& B = tmp2.M;
max@0 64 const Mat<eT>& C = tmp3.M;
max@0 65
max@0 66 const bool do_trans_A = tmp1.do_trans;
max@0 67 const bool do_trans_B = tmp2.do_trans;
max@0 68 const bool do_trans_C = tmp3.do_trans;
max@0 69
max@0 70 const bool use_alpha = tmp1.do_times || tmp2.do_times || tmp3.do_times;
max@0 71 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val()) : eT(0);
max@0 72
max@0 73 glue_times::apply(out, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha);
max@0 74 }
max@0 75
max@0 76
max@0 77
max@0 78 template<typename T1, typename T2, typename T3, typename T4>
max@0 79 inline
max@0 80 void
max@0 81 glue_times_redirect<4>::apply(Mat<typename T1::elem_type>& out, const Glue< Glue< Glue<T1,T2,glue_times>, T3, glue_times>, T4, glue_times>& X)
max@0 82 {
max@0 83 arma_extra_debug_sigprint();
max@0 84
max@0 85 typedef typename T1::elem_type eT;
max@0 86
max@0 87 // there is exactly 4 objects
max@0 88 // hence we can safely expand X as X.A.A.A, X.A.A.B, X.A.B and X.B
max@0 89
max@0 90 const partial_unwrap_check<T1> tmp1(X.A.A.A, out);
max@0 91 const partial_unwrap_check<T2> tmp2(X.A.A.B, out);
max@0 92 const partial_unwrap_check<T3> tmp3(X.A.B, out);
max@0 93 const partial_unwrap_check<T4> tmp4(X.B, out);
max@0 94
max@0 95 const Mat<eT>& A = tmp1.M;
max@0 96 const Mat<eT>& B = tmp2.M;
max@0 97 const Mat<eT>& C = tmp3.M;
max@0 98 const Mat<eT>& D = tmp4.M;
max@0 99
max@0 100 const bool do_trans_A = tmp1.do_trans;
max@0 101 const bool do_trans_B = tmp2.do_trans;
max@0 102 const bool do_trans_C = tmp3.do_trans;
max@0 103 const bool do_trans_D = tmp4.do_trans;
max@0 104
max@0 105 const bool use_alpha = tmp1.do_times || tmp2.do_times || tmp3.do_times || tmp4.do_times;
max@0 106 const eT alpha = use_alpha ? (tmp1.get_val() * tmp2.get_val() * tmp3.get_val() * tmp4.get_val()) : eT(0);
max@0 107
max@0 108 glue_times::apply(out, A, B, C, D, alpha, do_trans_A, do_trans_B, do_trans_C, do_trans_D, use_alpha);
max@0 109 }
max@0 110
max@0 111
max@0 112
max@0 113 template<typename T1, typename T2>
max@0 114 inline
max@0 115 void
max@0 116 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
max@0 117 {
max@0 118 arma_extra_debug_sigprint();
max@0 119
max@0 120 typedef typename T1::elem_type eT;
max@0 121
max@0 122 const sword N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num;
max@0 123
max@0 124 arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat);
max@0 125
max@0 126 glue_times_redirect<N_mat>::apply(out, X);
max@0 127 }
max@0 128
max@0 129
max@0 130
max@0 131 template<typename T1>
max@0 132 inline
max@0 133 void
max@0 134 glue_times::apply_inplace(Mat<typename T1::elem_type>& out, const T1& X)
max@0 135 {
max@0 136 arma_extra_debug_sigprint();
max@0 137
max@0 138 typedef typename T1::elem_type eT;
max@0 139
max@0 140 const unwrap_check<T1> tmp(X, out);
max@0 141 const Mat<eT>& B = tmp.M;
max@0 142
max@0 143 arma_debug_assert_mul_size(out, B, "matrix multiplication");
max@0 144
max@0 145 const uword out_n_rows = out.n_rows;
max@0 146 const uword out_n_cols = out.n_cols;
max@0 147
max@0 148 if(out_n_cols == B.n_cols)
max@0 149 {
max@0 150 // size of resulting matrix is the same as 'out'
max@0 151
max@0 152 podarray<eT> tmp(out_n_cols);
max@0 153
max@0 154 eT* tmp_rowdata = tmp.memptr();
max@0 155
max@0 156 for(uword row=0; row < out_n_rows; ++row)
max@0 157 {
max@0 158 tmp.copy_row(out, row);
max@0 159
max@0 160 for(uword col=0; col < out_n_cols; ++col)
max@0 161 {
max@0 162 out.at(row,col) = op_dot::direct_dot( out_n_cols, tmp_rowdata, B.colptr(col) );
max@0 163 }
max@0 164 }
max@0 165
max@0 166 }
max@0 167 else
max@0 168 {
max@0 169 const Mat<eT> tmp(out);
max@0 170 glue_times::apply(out, tmp, B, eT(1), false, false, false);
max@0 171 }
max@0 172
max@0 173 }
max@0 174
max@0 175
max@0 176
max@0 177 template<typename T1, typename T2>
max@0 178 arma_hot
max@0 179 inline
max@0 180 void
max@0 181 glue_times::apply_inplace_plus(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times>& X, const sword sign)
max@0 182 {
max@0 183 arma_extra_debug_sigprint();
max@0 184
max@0 185 typedef typename T1::elem_type eT;
max@0 186
max@0 187 const partial_unwrap_check<T1> tmp1(X.A, out);
max@0 188 const partial_unwrap_check<T2> tmp2(X.B, out);
max@0 189
max@0 190 const Mat<eT>& A = tmp1.M;
max@0 191 const Mat<eT>& B = tmp2.M;
max@0 192 const eT alpha = tmp1.get_val() * tmp2.get_val() * ( (sign > sword(0)) ? eT(1) : eT(-1) );
max@0 193
max@0 194 const bool do_trans_A = tmp1.do_trans;
max@0 195 const bool do_trans_B = tmp2.do_trans;
max@0 196 const bool use_alpha = tmp1.do_times || tmp2.do_times || (sign < sword(0));
max@0 197
max@0 198 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication");
max@0 199
max@0 200 const uword result_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
max@0 201 const uword result_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
max@0 202
max@0 203 arma_assert_same_size(out.n_rows, out.n_cols, result_n_rows, result_n_cols, "addition");
max@0 204
max@0 205 if(out.n_elem > 0)
max@0 206 {
max@0 207 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
max@0 208 {
max@0 209 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
max@0 210 {
max@0 211 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
max@0 212 }
max@0 213 else
max@0 214 if(B.n_cols == 1)
max@0 215 {
max@0 216 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
max@0 217 }
max@0 218 else
max@0 219 {
max@0 220 gemm<false, false, false, true>::apply(out, A, B, alpha, eT(1));
max@0 221 }
max@0 222 }
max@0 223 else
max@0 224 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
max@0 225 {
max@0 226 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
max@0 227 {
max@0 228 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
max@0 229 }
max@0 230 else
max@0 231 if(B.n_cols == 1)
max@0 232 {
max@0 233 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
max@0 234 }
max@0 235 else
max@0 236 {
max@0 237 gemm<false, false, true, true>::apply(out, A, B, alpha, eT(1));
max@0 238 }
max@0 239 }
max@0 240 else
max@0 241 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
max@0 242 {
max@0 243 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
max@0 244 {
max@0 245 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
max@0 246 }
max@0 247 else
max@0 248 if(B.n_cols == 1)
max@0 249 {
max@0 250 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
max@0 251 }
max@0 252 else
max@0 253 {
max@0 254 gemm<true, false, false, true>::apply(out, A, B, alpha, eT(1));
max@0 255 }
max@0 256 }
max@0 257 else
max@0 258 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
max@0 259 {
max@0 260 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
max@0 261 {
max@0 262 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
max@0 263 }
max@0 264 else
max@0 265 if(B.n_cols == 1)
max@0 266 {
max@0 267 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
max@0 268 }
max@0 269 else
max@0 270 {
max@0 271 gemm<true, false, true, true>::apply(out, A, B, alpha, eT(1));
max@0 272 }
max@0 273 }
max@0 274 else
max@0 275 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
max@0 276 {
max@0 277 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
max@0 278 {
max@0 279 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
max@0 280 }
max@0 281 else
max@0 282 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
max@0 283 {
max@0 284 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
max@0 285 }
max@0 286 else
max@0 287 {
max@0 288 gemm<false, true, false, true>::apply(out, A, B, alpha, eT(1));
max@0 289 }
max@0 290 }
max@0 291 else
max@0 292 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
max@0 293 {
max@0 294 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
max@0 295 {
max@0 296 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
max@0 297 }
max@0 298 else
max@0 299 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
max@0 300 {
max@0 301 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
max@0 302 }
max@0 303 else
max@0 304 {
max@0 305 gemm<false, true, true, true>::apply(out, A, B, alpha, eT(1));
max@0 306 }
max@0 307 }
max@0 308 else
max@0 309 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
max@0 310 {
max@0 311 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
max@0 312 {
max@0 313 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
max@0 314 }
max@0 315 else
max@0 316 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
max@0 317 {
max@0 318 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
max@0 319 }
max@0 320 else
max@0 321 {
max@0 322 gemm<true, true, false, true>::apply(out, A, B, alpha, eT(1));
max@0 323 }
max@0 324 }
max@0 325 else
max@0 326 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
max@0 327 {
max@0 328 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
max@0 329 {
max@0 330 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1));
max@0 331 }
max@0 332 else
max@0 333 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
max@0 334 {
max@0 335 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1));
max@0 336 }
max@0 337 else
max@0 338 {
max@0 339 gemm<true, true, true, true>::apply(out, A, B, alpha, eT(1));
max@0 340 }
max@0 341 }
max@0 342 }
max@0 343
max@0 344
max@0 345 }
max@0 346
max@0 347
max@0 348
max@0 349 template<typename eT>
max@0 350 arma_inline
max@0 351 uword
max@0 352 glue_times::mul_storage_cost(const Mat<eT>& A, const Mat<eT>& B, const bool do_trans_A, const bool do_trans_B)
max@0 353 {
max@0 354 const uword final_A_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
max@0 355 const uword final_B_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
max@0 356
max@0 357 return final_A_n_rows * final_B_n_cols;
max@0 358 }
max@0 359
max@0 360
max@0 361
max@0 362 template<typename eT>
max@0 363 arma_hot
max@0 364 inline
max@0 365 void
max@0 366 glue_times::apply
max@0 367 (
max@0 368 Mat<eT>& out,
max@0 369 const Mat<eT>& A,
max@0 370 const Mat<eT>& B,
max@0 371 const eT alpha,
max@0 372 const bool do_trans_A,
max@0 373 const bool do_trans_B,
max@0 374 const bool use_alpha
max@0 375 )
max@0 376 {
max@0 377 arma_extra_debug_sigprint();
max@0 378
max@0 379 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiplication");
max@0 380
max@0 381 const uword final_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols;
max@0 382 const uword final_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows;
max@0 383
max@0 384 out.set_size(final_n_rows, final_n_cols);
max@0 385
max@0 386 if( (A.n_elem > 0) && (B.n_elem > 0) )
max@0 387 {
max@0 388 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) )
max@0 389 {
max@0 390 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
max@0 391 {
max@0 392 gemv<true, false, false>::apply(out.memptr(), B, A.memptr());
max@0 393 }
max@0 394 else
max@0 395 if(B.n_cols == 1)
max@0 396 {
max@0 397 gemv<false, false, false>::apply(out.memptr(), A, B.memptr());
max@0 398 }
max@0 399 else
max@0 400 {
max@0 401 gemm<false, false, false, false>::apply(out, A, B);
max@0 402 }
max@0 403 }
max@0 404 else
max@0 405 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) )
max@0 406 {
max@0 407 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
max@0 408 {
max@0 409 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
max@0 410 }
max@0 411 else
max@0 412 if(B.n_cols == 1)
max@0 413 {
max@0 414 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
max@0 415 }
max@0 416 else
max@0 417 {
max@0 418 gemm<false, false, true, false>::apply(out, A, B, alpha);
max@0 419 }
max@0 420 }
max@0 421 else
max@0 422 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) )
max@0 423 {
max@0 424 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
max@0 425 {
max@0 426 gemv<true, false, false>::apply(out.memptr(), B, A.memptr());
max@0 427 }
max@0 428 else
max@0 429 if(B.n_cols == 1)
max@0 430 {
max@0 431 gemv<true, false, false>::apply(out.memptr(), A, B.memptr());
max@0 432 }
max@0 433 else
max@0 434 {
max@0 435 gemm<true, false, false, false>::apply(out, A, B);
max@0 436 }
max@0 437 }
max@0 438 else
max@0 439 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) )
max@0 440 {
max@0 441 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
max@0 442 {
max@0 443 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
max@0 444 }
max@0 445 else
max@0 446 if(B.n_cols == 1)
max@0 447 {
max@0 448 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
max@0 449 }
max@0 450 else
max@0 451 {
max@0 452 gemm<true, false, true, false>::apply(out, A, B, alpha);
max@0 453 }
max@0 454 }
max@0 455 else
max@0 456 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) )
max@0 457 {
max@0 458 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
max@0 459 {
max@0 460 gemv<false, false, false>::apply(out.memptr(), B, A.memptr());
max@0 461 }
max@0 462 else
max@0 463 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
max@0 464 {
max@0 465 gemv<false, false, false>::apply(out.memptr(), A, B.memptr());
max@0 466 }
max@0 467 else
max@0 468 {
max@0 469 gemm<false, true, false, false>::apply(out, A, B);
max@0 470 }
max@0 471 }
max@0 472 else
max@0 473 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) )
max@0 474 {
max@0 475 if( (A.n_rows == 1) && (is_complex<eT>::value == false) )
max@0 476 {
max@0 477 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
max@0 478 }
max@0 479 else
max@0 480 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
max@0 481 {
max@0 482 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
max@0 483 }
max@0 484 else
max@0 485 {
max@0 486 gemm<false, true, true, false>::apply(out, A, B, alpha);
max@0 487 }
max@0 488 }
max@0 489 else
max@0 490 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) )
max@0 491 {
max@0 492 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
max@0 493 {
max@0 494 gemv<false, false, false>::apply(out.memptr(), B, A.memptr());
max@0 495 }
max@0 496 else
max@0 497 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
max@0 498 {
max@0 499 gemv<true, false, false>::apply(out.memptr(), A, B.memptr());
max@0 500 }
max@0 501 else
max@0 502 {
max@0 503 gemm<true, true, false, false>::apply(out, A, B);
max@0 504 }
max@0 505 }
max@0 506 else
max@0 507 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) )
max@0 508 {
max@0 509 if( (A.n_cols == 1) && (is_complex<eT>::value == false) )
max@0 510 {
max@0 511 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha);
max@0 512 }
max@0 513 else
max@0 514 if( (B.n_rows == 1) && (is_complex<eT>::value == false) )
max@0 515 {
max@0 516 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha);
max@0 517 }
max@0 518 else
max@0 519 {
max@0 520 gemm<true, true, true, false>::apply(out, A, B, alpha);
max@0 521 }
max@0 522 }
max@0 523 }
max@0 524 else
max@0 525 {
max@0 526 out.zeros();
max@0 527 }
max@0 528 }
max@0 529
max@0 530
max@0 531
max@0 532 template<typename eT>
max@0 533 inline
max@0 534 void
max@0 535 glue_times::apply
max@0 536 (
max@0 537 Mat<eT>& out,
max@0 538 const Mat<eT>& A,
max@0 539 const Mat<eT>& B,
max@0 540 const Mat<eT>& C,
max@0 541 const eT alpha,
max@0 542 const bool do_trans_A,
max@0 543 const bool do_trans_B,
max@0 544 const bool do_trans_C,
max@0 545 const bool use_alpha
max@0 546 )
max@0 547 {
max@0 548 arma_extra_debug_sigprint();
max@0 549
max@0 550 Mat<eT> tmp;
max@0 551
max@0 552 if( glue_times::mul_storage_cost(A, B, do_trans_A, do_trans_B) <= glue_times::mul_storage_cost(B, C, do_trans_B, do_trans_C) )
max@0 553 {
max@0 554 // out = (A*B)*C
max@0 555 glue_times::apply(tmp, A, B, alpha, do_trans_A, do_trans_B, use_alpha);
max@0 556 glue_times::apply(out, tmp, C, eT(0), false, do_trans_C, false );
max@0 557 }
max@0 558 else
max@0 559 {
max@0 560 // out = A*(B*C)
max@0 561 glue_times::apply(tmp, B, C, alpha, do_trans_B, do_trans_C, use_alpha);
max@0 562 glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false );
max@0 563 }
max@0 564 }
max@0 565
max@0 566
max@0 567
max@0 568 template<typename eT>
max@0 569 inline
max@0 570 void
max@0 571 glue_times::apply
max@0 572 (
max@0 573 Mat<eT>& out,
max@0 574 const Mat<eT>& A,
max@0 575 const Mat<eT>& B,
max@0 576 const Mat<eT>& C,
max@0 577 const Mat<eT>& D,
max@0 578 const eT alpha,
max@0 579 const bool do_trans_A,
max@0 580 const bool do_trans_B,
max@0 581 const bool do_trans_C,
max@0 582 const bool do_trans_D,
max@0 583 const bool use_alpha
max@0 584 )
max@0 585 {
max@0 586 arma_extra_debug_sigprint();
max@0 587
max@0 588 Mat<eT> tmp;
max@0 589
max@0 590 if( glue_times::mul_storage_cost(A, C, do_trans_A, do_trans_C) <= glue_times::mul_storage_cost(B, D, do_trans_B, do_trans_D) )
max@0 591 {
max@0 592 // out = (A*B*C)*D
max@0 593 glue_times::apply(tmp, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha);
max@0 594
max@0 595 glue_times::apply(out, tmp, D, eT(0), false, do_trans_D, false);
max@0 596 }
max@0 597 else
max@0 598 {
max@0 599 // out = A*(B*C*D)
max@0 600 glue_times::apply(tmp, B, C, D, alpha, do_trans_B, do_trans_C, do_trans_D, use_alpha);
max@0 601
max@0 602 glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false);
max@0 603 }
max@0 604 }
max@0 605
max@0 606
max@0 607
max@0 608 //
max@0 609 // glue_times_diag
max@0 610
max@0 611
max@0 612 template<typename T1, typename T2>
max@0 613 arma_hot
max@0 614 inline
max@0 615 void
max@0 616 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_times_diag>& X)
max@0 617 {
max@0 618 arma_extra_debug_sigprint();
max@0 619
max@0 620 typedef typename T1::elem_type eT;
max@0 621
max@0 622 const strip_diagmat<T1> S1(X.A);
max@0 623 const strip_diagmat<T2> S2(X.B);
max@0 624
max@0 625 typedef typename strip_diagmat<T1>::stored_type T1_stripped;
max@0 626 typedef typename strip_diagmat<T2>::stored_type T2_stripped;
max@0 627
max@0 628 if( (S1.do_diagmat == true) && (S2.do_diagmat == false) )
max@0 629 {
max@0 630 const diagmat_proxy_check<T1_stripped> A(S1.M, out);
max@0 631
max@0 632 const unwrap_check<T2> tmp(X.B, out);
max@0 633 const Mat<eT>& B = tmp.M;
max@0 634
max@0 635 arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_rows, B.n_cols, "matrix multiplication");
max@0 636
max@0 637 out.set_size(A.n_elem, B.n_cols);
max@0 638
max@0 639 for(uword col=0; col<B.n_cols; ++col)
max@0 640 {
max@0 641 eT* out_coldata = out.colptr(col);
max@0 642 const eT* B_coldata = B.colptr(col);
max@0 643
max@0 644 for(uword row=0; row<B.n_rows; ++row)
max@0 645 {
max@0 646 out_coldata[row] = A[row] * B_coldata[row];
max@0 647 }
max@0 648 }
max@0 649 }
max@0 650 else
max@0 651 if( (S1.do_diagmat == false) && (S2.do_diagmat == true) )
max@0 652 {
max@0 653 const unwrap_check<T1> tmp(X.A, out);
max@0 654 const Mat<eT>& A = tmp.M;
max@0 655
max@0 656 const diagmat_proxy_check<T2_stripped> B(S2.M, out);
max@0 657
max@0 658 arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_elem, B.n_elem, "matrix multiplication");
max@0 659
max@0 660 out.set_size(A.n_rows, B.n_elem);
max@0 661
max@0 662 for(uword col=0; col<A.n_cols; ++col)
max@0 663 {
max@0 664 const eT val = B[col];
max@0 665
max@0 666 eT* out_coldata = out.colptr(col);
max@0 667 const eT* A_coldata = A.colptr(col);
max@0 668
max@0 669 for(uword row=0; row<A.n_rows; ++row)
max@0 670 {
max@0 671 out_coldata[row] = A_coldata[row] * val;
max@0 672 }
max@0 673 }
max@0 674 }
max@0 675 else
max@0 676 if( (S1.do_diagmat == true) && (S2.do_diagmat == true) )
max@0 677 {
max@0 678 const diagmat_proxy_check<T1_stripped> A(S1.M, out);
max@0 679 const diagmat_proxy_check<T2_stripped> B(S2.M, out);
max@0 680
max@0 681 arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_elem, B.n_elem, "matrix multiplication");
max@0 682
max@0 683 out.zeros(A.n_elem, A.n_elem);
max@0 684
max@0 685 for(uword i=0; i<A.n_elem; ++i)
max@0 686 {
max@0 687 out.at(i,i) = A[i] * B[i];
max@0 688 }
max@0 689 }
max@0 690 }
max@0 691
max@0 692
max@0 693
max@0 694 //! @}