annotate armadillo-3.900.4/include/armadillo_bits/gemm.hpp @ 84:55a047986812 tip

Update library URI so as not to be document-local
author Chris Cannam
date Wed, 22 Apr 2020 14:21:57 +0100
parents 1ec0e2823891
children
rev   line source
Chris@49 1 // Copyright (C) 2008-2013 NICTA (www.nicta.com.au)
Chris@49 2 // Copyright (C) 2008-2013 Conrad Sanderson
Chris@49 3 //
Chris@49 4 // This Source Code Form is subject to the terms of the Mozilla Public
Chris@49 5 // License, v. 2.0. If a copy of the MPL was not distributed with this
Chris@49 6 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
Chris@49 7
Chris@49 8
Chris@49 9 //! \addtogroup gemm
Chris@49 10 //! @{
Chris@49 11
Chris@49 12
Chris@49 13
Chris@49 14 //! for tiny square matrices, size <= 4x4
Chris@49 15 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
Chris@49 16 class gemm_emul_tinysq
Chris@49 17 {
Chris@49 18 public:
Chris@49 19
Chris@49 20
Chris@49 21 template<typename eT, typename TA, typename TB>
Chris@49 22 arma_hot
Chris@49 23 inline
Chris@49 24 static
Chris@49 25 void
Chris@49 26 apply
Chris@49 27 (
Chris@49 28 Mat<eT>& C,
Chris@49 29 const TA& A,
Chris@49 30 const TB& B,
Chris@49 31 const eT alpha = eT(1),
Chris@49 32 const eT beta = eT(0)
Chris@49 33 )
Chris@49 34 {
Chris@49 35 arma_extra_debug_sigprint();
Chris@49 36
Chris@49 37 switch(A.n_rows)
Chris@49 38 {
Chris@49 39 case 4:
Chris@49 40 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(3), A, B.colptr(3), alpha, beta );
Chris@49 41
Chris@49 42 case 3:
Chris@49 43 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(2), A, B.colptr(2), alpha, beta );
Chris@49 44
Chris@49 45 case 2:
Chris@49 46 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(1), A, B.colptr(1), alpha, beta );
Chris@49 47
Chris@49 48 case 1:
Chris@49 49 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(0), A, B.colptr(0), alpha, beta );
Chris@49 50
Chris@49 51 default:
Chris@49 52 ;
Chris@49 53 }
Chris@49 54 }
Chris@49 55
Chris@49 56 };
Chris@49 57
Chris@49 58
Chris@49 59
Chris@49 60 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
Chris@49 61 class gemm_emul_large
Chris@49 62 {
Chris@49 63 public:
Chris@49 64
Chris@49 65 template<typename eT, typename TA, typename TB>
Chris@49 66 arma_hot
Chris@49 67 inline
Chris@49 68 static
Chris@49 69 void
Chris@49 70 apply
Chris@49 71 (
Chris@49 72 Mat<eT>& C,
Chris@49 73 const TA& A,
Chris@49 74 const TB& B,
Chris@49 75 const eT alpha = eT(1),
Chris@49 76 const eT beta = eT(0)
Chris@49 77 )
Chris@49 78 {
Chris@49 79 arma_extra_debug_sigprint();
Chris@49 80
Chris@49 81 const uword A_n_rows = A.n_rows;
Chris@49 82 const uword A_n_cols = A.n_cols;
Chris@49 83
Chris@49 84 const uword B_n_rows = B.n_rows;
Chris@49 85 const uword B_n_cols = B.n_cols;
Chris@49 86
Chris@49 87 if( (do_trans_A == false) && (do_trans_B == false) )
Chris@49 88 {
Chris@49 89 arma_aligned podarray<eT> tmp(A_n_cols);
Chris@49 90
Chris@49 91 eT* A_rowdata = tmp.memptr();
Chris@49 92
Chris@49 93 for(uword row_A=0; row_A < A_n_rows; ++row_A)
Chris@49 94 {
Chris@49 95 //tmp.copy_row(A, row_A);
Chris@49 96 const eT acc0 = op_dot::dot_and_copy_row(A_rowdata, A, row_A, B.colptr(0), A_n_cols);
Chris@49 97
Chris@49 98 if( (use_alpha == false) && (use_beta == false) )
Chris@49 99 {
Chris@49 100 C.at(row_A,0) = acc0;
Chris@49 101 }
Chris@49 102 else
Chris@49 103 if( (use_alpha == true) && (use_beta == false) )
Chris@49 104 {
Chris@49 105 C.at(row_A,0) = alpha * acc0;
Chris@49 106 }
Chris@49 107 else
Chris@49 108 if( (use_alpha == false) && (use_beta == true) )
Chris@49 109 {
Chris@49 110 C.at(row_A,0) = acc0 + beta*C.at(row_A,0);
Chris@49 111 }
Chris@49 112 else
Chris@49 113 if( (use_alpha == true) && (use_beta == true) )
Chris@49 114 {
Chris@49 115 C.at(row_A,0) = alpha*acc0 + beta*C.at(row_A,0);
Chris@49 116 }
Chris@49 117
Chris@49 118 //for(uword col_B=0; col_B < B_n_cols; ++col_B)
Chris@49 119 for(uword col_B=1; col_B < B_n_cols; ++col_B)
Chris@49 120 {
Chris@49 121 const eT acc = op_dot::direct_dot_arma(B_n_rows, A_rowdata, B.colptr(col_B));
Chris@49 122
Chris@49 123 if( (use_alpha == false) && (use_beta == false) )
Chris@49 124 {
Chris@49 125 C.at(row_A,col_B) = acc;
Chris@49 126 }
Chris@49 127 else
Chris@49 128 if( (use_alpha == true) && (use_beta == false) )
Chris@49 129 {
Chris@49 130 C.at(row_A,col_B) = alpha * acc;
Chris@49 131 }
Chris@49 132 else
Chris@49 133 if( (use_alpha == false) && (use_beta == true) )
Chris@49 134 {
Chris@49 135 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
Chris@49 136 }
Chris@49 137 else
Chris@49 138 if( (use_alpha == true) && (use_beta == true) )
Chris@49 139 {
Chris@49 140 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
Chris@49 141 }
Chris@49 142
Chris@49 143 }
Chris@49 144 }
Chris@49 145 }
Chris@49 146 else
Chris@49 147 if( (do_trans_A == true) && (do_trans_B == false) )
Chris@49 148 {
Chris@49 149 for(uword col_A=0; col_A < A_n_cols; ++col_A)
Chris@49 150 {
Chris@49 151 // col_A is interpreted as row_A when storing the results in matrix C
Chris@49 152
Chris@49 153 const eT* A_coldata = A.colptr(col_A);
Chris@49 154
Chris@49 155 for(uword col_B=0; col_B < B_n_cols; ++col_B)
Chris@49 156 {
Chris@49 157 const eT acc = op_dot::direct_dot_arma(B_n_rows, A_coldata, B.colptr(col_B));
Chris@49 158
Chris@49 159 if( (use_alpha == false) && (use_beta == false) )
Chris@49 160 {
Chris@49 161 C.at(col_A,col_B) = acc;
Chris@49 162 }
Chris@49 163 else
Chris@49 164 if( (use_alpha == true) && (use_beta == false) )
Chris@49 165 {
Chris@49 166 C.at(col_A,col_B) = alpha * acc;
Chris@49 167 }
Chris@49 168 else
Chris@49 169 if( (use_alpha == false) && (use_beta == true) )
Chris@49 170 {
Chris@49 171 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
Chris@49 172 }
Chris@49 173 else
Chris@49 174 if( (use_alpha == true) && (use_beta == true) )
Chris@49 175 {
Chris@49 176 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
Chris@49 177 }
Chris@49 178
Chris@49 179 }
Chris@49 180 }
Chris@49 181 }
Chris@49 182 else
Chris@49 183 if( (do_trans_A == false) && (do_trans_B == true) )
Chris@49 184 {
Chris@49 185 Mat<eT> BB;
Chris@49 186 op_strans::apply_noalias(BB, B);
Chris@49 187
Chris@49 188 gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, A, BB, alpha, beta);
Chris@49 189 }
Chris@49 190 else
Chris@49 191 if( (do_trans_A == true) && (do_trans_B == true) )
Chris@49 192 {
Chris@49 193 // mat B_tmp = trans(B);
Chris@49 194 // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
Chris@49 195
Chris@49 196
Chris@49 197 // By using the trans(A)*trans(B) = trans(B*A) equivalency,
Chris@49 198 // transpose operations are not needed
Chris@49 199
Chris@49 200 arma_aligned podarray<eT> tmp(B.n_cols);
Chris@49 201 eT* B_rowdata = tmp.memptr();
Chris@49 202
Chris@49 203 for(uword row_B=0; row_B < B_n_rows; ++row_B)
Chris@49 204 {
Chris@49 205 tmp.copy_row(B, row_B);
Chris@49 206
Chris@49 207 for(uword col_A=0; col_A < A_n_cols; ++col_A)
Chris@49 208 {
Chris@49 209 const eT acc = op_dot::direct_dot_arma(A_n_rows, B_rowdata, A.colptr(col_A));
Chris@49 210
Chris@49 211 if( (use_alpha == false) && (use_beta == false) )
Chris@49 212 {
Chris@49 213 C.at(col_A,row_B) = acc;
Chris@49 214 }
Chris@49 215 else
Chris@49 216 if( (use_alpha == true) && (use_beta == false) )
Chris@49 217 {
Chris@49 218 C.at(col_A,row_B) = alpha * acc;
Chris@49 219 }
Chris@49 220 else
Chris@49 221 if( (use_alpha == false) && (use_beta == true) )
Chris@49 222 {
Chris@49 223 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
Chris@49 224 }
Chris@49 225 else
Chris@49 226 if( (use_alpha == true) && (use_beta == true) )
Chris@49 227 {
Chris@49 228 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
Chris@49 229 }
Chris@49 230
Chris@49 231 }
Chris@49 232 }
Chris@49 233
Chris@49 234 }
Chris@49 235 }
Chris@49 236
Chris@49 237 };
Chris@49 238
Chris@49 239
Chris@49 240
Chris@49 241 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
Chris@49 242 class gemm_emul
Chris@49 243 {
Chris@49 244 public:
Chris@49 245
Chris@49 246
Chris@49 247 template<typename eT, typename TA, typename TB>
Chris@49 248 arma_hot
Chris@49 249 inline
Chris@49 250 static
Chris@49 251 void
Chris@49 252 apply
Chris@49 253 (
Chris@49 254 Mat<eT>& C,
Chris@49 255 const TA& A,
Chris@49 256 const TB& B,
Chris@49 257 const eT alpha = eT(1),
Chris@49 258 const eT beta = eT(0),
Chris@49 259 const typename arma_not_cx<eT>::result* junk = 0
Chris@49 260 )
Chris@49 261 {
Chris@49 262 arma_extra_debug_sigprint();
Chris@49 263 arma_ignore(junk);
Chris@49 264
Chris@49 265 const uword A_n_rows = A.n_rows;
Chris@49 266 const uword A_n_cols = A.n_cols;
Chris@49 267
Chris@49 268 const uword B_n_rows = B.n_rows;
Chris@49 269 const uword B_n_cols = B.n_cols;
Chris@49 270
Chris@49 271 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) && (A_n_rows == B_n_rows) && (B_n_rows == B_n_cols) )
Chris@49 272 {
Chris@49 273 if(do_trans_B == false)
Chris@49 274 {
Chris@49 275 gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, B, alpha, beta);
Chris@49 276 }
Chris@49 277 else
Chris@49 278 {
Chris@49 279 Mat<eT> BB(A_n_rows, A_n_rows);
Chris@49 280 op_strans::apply_noalias_tinysq(BB, B);
Chris@49 281
Chris@49 282 gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, BB, alpha, beta);
Chris@49 283 }
Chris@49 284 }
Chris@49 285 else
Chris@49 286 {
Chris@49 287 gemm_emul_large<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C, A, B, alpha, beta);
Chris@49 288 }
Chris@49 289 }
Chris@49 290
Chris@49 291
Chris@49 292
Chris@49 293 template<typename eT>
Chris@49 294 arma_hot
Chris@49 295 inline
Chris@49 296 static
Chris@49 297 void
Chris@49 298 apply
Chris@49 299 (
Chris@49 300 Mat<eT>& C,
Chris@49 301 const Mat<eT>& A,
Chris@49 302 const Mat<eT>& B,
Chris@49 303 const eT alpha = eT(1),
Chris@49 304 const eT beta = eT(0),
Chris@49 305 const typename arma_cx_only<eT>::result* junk = 0
Chris@49 306 )
Chris@49 307 {
Chris@49 308 arma_extra_debug_sigprint();
Chris@49 309 arma_ignore(junk);
Chris@49 310
Chris@49 311 // "better than nothing" handling of hermitian transposes for complex number matrices
Chris@49 312
Chris@49 313 Mat<eT> tmp_A;
Chris@49 314 Mat<eT> tmp_B;
Chris@49 315
Chris@49 316 if(do_trans_A)
Chris@49 317 {
Chris@49 318 op_htrans::apply_noalias(tmp_A, A);
Chris@49 319 }
Chris@49 320
Chris@49 321 if(do_trans_B)
Chris@49 322 {
Chris@49 323 op_htrans::apply_noalias(tmp_B, B);
Chris@49 324 }
Chris@49 325
Chris@49 326 const Mat<eT>& AA = (do_trans_A == false) ? A : tmp_A;
Chris@49 327 const Mat<eT>& BB = (do_trans_B == false) ? B : tmp_B;
Chris@49 328
Chris@49 329 const uword A_n_rows = AA.n_rows;
Chris@49 330 const uword A_n_cols = AA.n_cols;
Chris@49 331
Chris@49 332 const uword B_n_rows = BB.n_rows;
Chris@49 333 const uword B_n_cols = BB.n_cols;
Chris@49 334
Chris@49 335 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) && (A_n_rows == B_n_rows) && (B_n_rows == B_n_cols) )
Chris@49 336 {
Chris@49 337 gemm_emul_tinysq<false, use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
Chris@49 338 }
Chris@49 339 else
Chris@49 340 {
Chris@49 341 gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
Chris@49 342 }
Chris@49 343 }
Chris@49 344
Chris@49 345 };
Chris@49 346
Chris@49 347
Chris@49 348
Chris@49 349 //! \brief
Chris@49 350 //! Wrapper for ATLAS/BLAS dgemm function, using template arguments to control the arguments passed to dgemm.
Chris@49 351 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
Chris@49 352
Chris@49 353 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
Chris@49 354 class gemm
Chris@49 355 {
Chris@49 356 public:
Chris@49 357
Chris@49 358 template<typename eT, typename TA, typename TB>
Chris@49 359 inline
Chris@49 360 static
Chris@49 361 void
Chris@49 362 apply_blas_type( Mat<eT>& C, const TA& A, const TB& B, const eT alpha = eT(1), const eT beta = eT(0) )
Chris@49 363 {
Chris@49 364 arma_extra_debug_sigprint();
Chris@49 365
Chris@49 366 const uword threshold = (is_Mat_fixed<TA>::value && is_Mat_fixed<TB>::value)
Chris@49 367 ? (is_complex<eT>::value ? 16u : 64u)
Chris@49 368 : (is_complex<eT>::value ? 16u : 48u);
Chris@49 369
Chris@49 370 if( (A.n_elem <= threshold) && (B.n_elem <= threshold) )
Chris@49 371 {
Chris@49 372 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
Chris@49 373 }
Chris@49 374 else
Chris@49 375 {
Chris@49 376 #if defined(ARMA_USE_ATLAS)
Chris@49 377 {
Chris@49 378 arma_extra_debug_print("atlas::cblas_gemm()");
Chris@49 379
Chris@49 380 atlas::cblas_gemm<eT>
Chris@49 381 (
Chris@49 382 atlas::CblasColMajor,
Chris@49 383 (do_trans_A) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
Chris@49 384 (do_trans_B) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
Chris@49 385 C.n_rows,
Chris@49 386 C.n_cols,
Chris@49 387 (do_trans_A) ? A.n_rows : A.n_cols,
Chris@49 388 (use_alpha) ? alpha : eT(1),
Chris@49 389 A.mem,
Chris@49 390 (do_trans_A) ? A.n_rows : C.n_rows,
Chris@49 391 B.mem,
Chris@49 392 (do_trans_B) ? C.n_cols : ( (do_trans_A) ? A.n_rows : A.n_cols ),
Chris@49 393 (use_beta) ? beta : eT(0),
Chris@49 394 C.memptr(),
Chris@49 395 C.n_rows
Chris@49 396 );
Chris@49 397 }
Chris@49 398 #elif defined(ARMA_USE_BLAS)
Chris@49 399 {
Chris@49 400 arma_extra_debug_print("blas::gemm()");
Chris@49 401
Chris@49 402 const char trans_A = (do_trans_A) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N';
Chris@49 403 const char trans_B = (do_trans_B) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N';
Chris@49 404
Chris@49 405 const blas_int m = C.n_rows;
Chris@49 406 const blas_int n = C.n_cols;
Chris@49 407 const blas_int k = (do_trans_A) ? A.n_rows : A.n_cols;
Chris@49 408
Chris@49 409 const eT local_alpha = (use_alpha) ? alpha : eT(1);
Chris@49 410
Chris@49 411 const blas_int lda = (do_trans_A) ? k : m;
Chris@49 412 const blas_int ldb = (do_trans_B) ? n : k;
Chris@49 413
Chris@49 414 const eT local_beta = (use_beta) ? beta : eT(0);
Chris@49 415
Chris@49 416 arma_extra_debug_print( arma_boost::format("blas::gemm(): trans_A = %c") % trans_A );
Chris@49 417 arma_extra_debug_print( arma_boost::format("blas::gemm(): trans_B = %c") % trans_B );
Chris@49 418
Chris@49 419 blas::gemm<eT>
Chris@49 420 (
Chris@49 421 &trans_A,
Chris@49 422 &trans_B,
Chris@49 423 &m,
Chris@49 424 &n,
Chris@49 425 &k,
Chris@49 426 &local_alpha,
Chris@49 427 A.mem,
Chris@49 428 &lda,
Chris@49 429 B.mem,
Chris@49 430 &ldb,
Chris@49 431 &local_beta,
Chris@49 432 C.memptr(),
Chris@49 433 &m
Chris@49 434 );
Chris@49 435 }
Chris@49 436 #else
Chris@49 437 {
Chris@49 438 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
Chris@49 439 }
Chris@49 440 #endif
Chris@49 441 }
Chris@49 442 }
Chris@49 443
Chris@49 444
Chris@49 445
Chris@49 446 //! immediate multiplication of matrices A and B, storing the result in C
Chris@49 447 template<typename eT, typename TA, typename TB>
Chris@49 448 inline
Chris@49 449 static
Chris@49 450 void
Chris@49 451 apply( Mat<eT>& C, const TA& A, const TB& B, const eT alpha = eT(1), const eT beta = eT(0) )
Chris@49 452 {
Chris@49 453 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
Chris@49 454 }
Chris@49 455
Chris@49 456
Chris@49 457
Chris@49 458 template<typename TA, typename TB>
Chris@49 459 arma_inline
Chris@49 460 static
Chris@49 461 void
Chris@49 462 apply
Chris@49 463 (
Chris@49 464 Mat<float>& C,
Chris@49 465 const TA& A,
Chris@49 466 const TB& B,
Chris@49 467 const float alpha = float(1),
Chris@49 468 const float beta = float(0)
Chris@49 469 )
Chris@49 470 {
Chris@49 471 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
Chris@49 472 }
Chris@49 473
Chris@49 474
Chris@49 475
Chris@49 476 template<typename TA, typename TB>
Chris@49 477 arma_inline
Chris@49 478 static
Chris@49 479 void
Chris@49 480 apply
Chris@49 481 (
Chris@49 482 Mat<double>& C,
Chris@49 483 const TA& A,
Chris@49 484 const TB& B,
Chris@49 485 const double alpha = double(1),
Chris@49 486 const double beta = double(0)
Chris@49 487 )
Chris@49 488 {
Chris@49 489 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
Chris@49 490 }
Chris@49 491
Chris@49 492
Chris@49 493
Chris@49 494 template<typename TA, typename TB>
Chris@49 495 arma_inline
Chris@49 496 static
Chris@49 497 void
Chris@49 498 apply
Chris@49 499 (
Chris@49 500 Mat< std::complex<float> >& C,
Chris@49 501 const TA& A,
Chris@49 502 const TB& B,
Chris@49 503 const std::complex<float> alpha = std::complex<float>(1),
Chris@49 504 const std::complex<float> beta = std::complex<float>(0)
Chris@49 505 )
Chris@49 506 {
Chris@49 507 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
Chris@49 508 }
Chris@49 509
Chris@49 510
Chris@49 511
Chris@49 512 template<typename TA, typename TB>
Chris@49 513 arma_inline
Chris@49 514 static
Chris@49 515 void
Chris@49 516 apply
Chris@49 517 (
Chris@49 518 Mat< std::complex<double> >& C,
Chris@49 519 const TA& A,
Chris@49 520 const TB& B,
Chris@49 521 const std::complex<double> alpha = std::complex<double>(1),
Chris@49 522 const std::complex<double> beta = std::complex<double>(0)
Chris@49 523 )
Chris@49 524 {
Chris@49 525 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
Chris@49 526 }
Chris@49 527
Chris@49 528 };
Chris@49 529
Chris@49 530
Chris@49 531
Chris@49 532 //! @}