annotate armadillo-2.4.4/include/armadillo_bits/gemm.hpp @ 5:79b343f3e4b8

In this version, the problem of the letters assigned to each segment has been solved.
author maxzanoni76 <max.zanoni@eecs.qmul.ac.uk>
date Wed, 11 Apr 2012 13:48:13 +0100
parents 8b6102e2a9b0
children
rev   line source
max@0 1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
max@0 2 // Copyright (C) 2008-2011 Conrad Sanderson
max@0 3 //
max@0 4 // This file is part of the Armadillo C++ library.
max@0 5 // It is provided without any warranty of fitness
max@0 6 // for any purpose. You can redistribute this file
max@0 7 // and/or modify it under the terms of the GNU
max@0 8 // Lesser General Public License (LGPL) as published
max@0 9 // by the Free Software Foundation, either version 3
max@0 10 // of the License or (at your option) any later version.
max@0 11 // (see http://www.opensource.org/licenses for more info)
max@0 12
max@0 13
max@0 14 //! \addtogroup gemm
max@0 15 //! @{
max@0 16
max@0 17
max@0 18
max@0 19 //! for tiny square matrices, size <= 4x4
max@0 20 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
max@0 21 class gemm_emul_tinysq
max@0 22 {
max@0 23 public:
max@0 24
max@0 25
max@0 26 template<typename eT>
max@0 27 arma_hot
max@0 28 inline
max@0 29 static
max@0 30 void
max@0 31 apply
max@0 32 (
max@0 33 Mat<eT>& C,
max@0 34 const Mat<eT>& A,
max@0 35 const Mat<eT>& B,
max@0 36 const eT alpha = eT(1),
max@0 37 const eT beta = eT(0)
max@0 38 )
max@0 39 {
max@0 40 arma_extra_debug_sigprint();
max@0 41
max@0 42 switch(A.n_rows)
max@0 43 {
max@0 44 case 4:
max@0 45 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(3), A, B.colptr(3), alpha, beta );
max@0 46
max@0 47 case 3:
max@0 48 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(2), A, B.colptr(2), alpha, beta );
max@0 49
max@0 50 case 2:
max@0 51 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(1), A, B.colptr(1), alpha, beta );
max@0 52
max@0 53 case 1:
max@0 54 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(0), A, B.colptr(0), alpha, beta );
max@0 55
max@0 56 default:
max@0 57 ;
max@0 58 }
max@0 59 }
max@0 60
max@0 61 };
max@0 62
max@0 63
max@0 64
max@0 65 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
max@0 66 class gemm_emul_large
max@0 67 {
max@0 68 public:
max@0 69
max@0 70 template<typename eT>
max@0 71 arma_hot
max@0 72 inline
max@0 73 static
max@0 74 void
max@0 75 apply
max@0 76 (
max@0 77 Mat<eT>& C,
max@0 78 const Mat<eT>& A,
max@0 79 const Mat<eT>& B,
max@0 80 const eT alpha = eT(1),
max@0 81 const eT beta = eT(0)
max@0 82 )
max@0 83 {
max@0 84 arma_extra_debug_sigprint();
max@0 85
max@0 86 const uword A_n_rows = A.n_rows;
max@0 87 const uword A_n_cols = A.n_cols;
max@0 88
max@0 89 const uword B_n_rows = B.n_rows;
max@0 90 const uword B_n_cols = B.n_cols;
max@0 91
max@0 92 if( (do_trans_A == false) && (do_trans_B == false) )
max@0 93 {
max@0 94 arma_aligned podarray<eT> tmp(A_n_cols);
max@0 95 eT* A_rowdata = tmp.memptr();
max@0 96
max@0 97 for(uword row_A=0; row_A < A_n_rows; ++row_A)
max@0 98 {
max@0 99 tmp.copy_row(A, row_A);
max@0 100
max@0 101 for(uword col_B=0; col_B < B_n_cols; ++col_B)
max@0 102 {
max@0 103 const eT acc = op_dot::direct_dot_arma(B_n_rows, A_rowdata, B.colptr(col_B));
max@0 104
max@0 105 if( (use_alpha == false) && (use_beta == false) )
max@0 106 {
max@0 107 C.at(row_A,col_B) = acc;
max@0 108 }
max@0 109 else
max@0 110 if( (use_alpha == true) && (use_beta == false) )
max@0 111 {
max@0 112 C.at(row_A,col_B) = alpha * acc;
max@0 113 }
max@0 114 else
max@0 115 if( (use_alpha == false) && (use_beta == true) )
max@0 116 {
max@0 117 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
max@0 118 }
max@0 119 else
max@0 120 if( (use_alpha == true) && (use_beta == true) )
max@0 121 {
max@0 122 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
max@0 123 }
max@0 124
max@0 125 }
max@0 126 }
max@0 127 }
max@0 128 else
max@0 129 if( (do_trans_A == true) && (do_trans_B == false) )
max@0 130 {
max@0 131 for(uword col_A=0; col_A < A_n_cols; ++col_A)
max@0 132 {
max@0 133 // col_A is interpreted as row_A when storing the results in matrix C
max@0 134
max@0 135 const eT* A_coldata = A.colptr(col_A);
max@0 136
max@0 137 for(uword col_B=0; col_B < B_n_cols; ++col_B)
max@0 138 {
max@0 139 const eT acc = op_dot::direct_dot_arma(B_n_rows, A_coldata, B.colptr(col_B));
max@0 140
max@0 141 if( (use_alpha == false) && (use_beta == false) )
max@0 142 {
max@0 143 C.at(col_A,col_B) = acc;
max@0 144 }
max@0 145 else
max@0 146 if( (use_alpha == true) && (use_beta == false) )
max@0 147 {
max@0 148 C.at(col_A,col_B) = alpha * acc;
max@0 149 }
max@0 150 else
max@0 151 if( (use_alpha == false) && (use_beta == true) )
max@0 152 {
max@0 153 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
max@0 154 }
max@0 155 else
max@0 156 if( (use_alpha == true) && (use_beta == true) )
max@0 157 {
max@0 158 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
max@0 159 }
max@0 160
max@0 161 }
max@0 162 }
max@0 163 }
max@0 164 else
max@0 165 if( (do_trans_A == false) && (do_trans_B == true) )
max@0 166 {
max@0 167 Mat<eT> BB;
max@0 168 op_strans::apply_noalias(BB, B);
max@0 169
max@0 170 gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, A, BB, alpha, beta);
max@0 171 }
max@0 172 else
max@0 173 if( (do_trans_A == true) && (do_trans_B == true) )
max@0 174 {
max@0 175 // mat B_tmp = trans(B);
max@0 176 // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
max@0 177
max@0 178
max@0 179 // By using the trans(A)*trans(B) = trans(B*A) equivalency,
max@0 180 // transpose operations are not needed
max@0 181
max@0 182 arma_aligned podarray<eT> tmp(B.n_cols);
max@0 183 eT* B_rowdata = tmp.memptr();
max@0 184
max@0 185 for(uword row_B=0; row_B < B_n_rows; ++row_B)
max@0 186 {
max@0 187 tmp.copy_row(B, row_B);
max@0 188
max@0 189 for(uword col_A=0; col_A < A_n_cols; ++col_A)
max@0 190 {
max@0 191 const eT acc = op_dot::direct_dot_arma(A_n_rows, B_rowdata, A.colptr(col_A));
max@0 192
max@0 193 if( (use_alpha == false) && (use_beta == false) )
max@0 194 {
max@0 195 C.at(col_A,row_B) = acc;
max@0 196 }
max@0 197 else
max@0 198 if( (use_alpha == true) && (use_beta == false) )
max@0 199 {
max@0 200 C.at(col_A,row_B) = alpha * acc;
max@0 201 }
max@0 202 else
max@0 203 if( (use_alpha == false) && (use_beta == true) )
max@0 204 {
max@0 205 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
max@0 206 }
max@0 207 else
max@0 208 if( (use_alpha == true) && (use_beta == true) )
max@0 209 {
max@0 210 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
max@0 211 }
max@0 212
max@0 213 }
max@0 214 }
max@0 215
max@0 216 }
max@0 217 }
max@0 218
max@0 219 };
max@0 220
max@0 221
max@0 222
max@0 223 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
max@0 224 class gemm_emul
max@0 225 {
max@0 226 public:
max@0 227
max@0 228
max@0 229 template<typename eT>
max@0 230 arma_hot
max@0 231 inline
max@0 232 static
max@0 233 void
max@0 234 apply
max@0 235 (
max@0 236 Mat<eT>& C,
max@0 237 const Mat<eT>& A,
max@0 238 const Mat<eT>& B,
max@0 239 const eT alpha = eT(1),
max@0 240 const eT beta = eT(0),
max@0 241 const typename arma_not_cx<eT>::result* junk = 0
max@0 242 )
max@0 243 {
max@0 244 arma_extra_debug_sigprint();
max@0 245 arma_ignore(junk);
max@0 246
max@0 247 const uword A_n_rows = A.n_rows;
max@0 248 const uword A_n_cols = A.n_cols;
max@0 249
max@0 250 const uword B_n_rows = B.n_rows;
max@0 251 const uword B_n_cols = B.n_cols;
max@0 252
max@0 253 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) && (A_n_rows == B_n_rows) && (B_n_rows == B_n_cols) )
max@0 254 {
max@0 255 if(do_trans_B == false)
max@0 256 {
max@0 257 gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, B, alpha, beta);
max@0 258 }
max@0 259 else
max@0 260 {
max@0 261 Mat<eT> BB(A_n_rows, A_n_rows);
max@0 262 op_strans::apply_noalias_tinysq(BB, B);
max@0 263
max@0 264 gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, BB, alpha, beta);
max@0 265 }
max@0 266 }
max@0 267 else
max@0 268 {
max@0 269 gemm_emul_large<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C, A, B, alpha, beta);
max@0 270 }
max@0 271 }
max@0 272
max@0 273
max@0 274
max@0 275 template<typename eT>
max@0 276 arma_hot
max@0 277 inline
max@0 278 static
max@0 279 void
max@0 280 apply
max@0 281 (
max@0 282 Mat<eT>& C,
max@0 283 const Mat<eT>& A,
max@0 284 const Mat<eT>& B,
max@0 285 const eT alpha = eT(1),
max@0 286 const eT beta = eT(0),
max@0 287 const typename arma_cx_only<eT>::result* junk = 0
max@0 288 )
max@0 289 {
max@0 290 arma_extra_debug_sigprint();
max@0 291 arma_ignore(junk);
max@0 292
max@0 293 // "better than nothing" handling of hermitian transposes for complex number matrices
max@0 294
max@0 295 Mat<eT> tmp_A;
max@0 296 Mat<eT> tmp_B;
max@0 297
max@0 298 if(do_trans_A)
max@0 299 {
max@0 300 op_htrans::apply_noalias(tmp_A, A);
max@0 301 }
max@0 302
max@0 303 if(do_trans_B)
max@0 304 {
max@0 305 op_htrans::apply_noalias(tmp_B, B);
max@0 306 }
max@0 307
max@0 308 const Mat<eT>& AA = (do_trans_A == false) ? A : tmp_A;
max@0 309 const Mat<eT>& BB = (do_trans_B == false) ? B : tmp_B;
max@0 310
max@0 311 const uword A_n_rows = AA.n_rows;
max@0 312 const uword A_n_cols = AA.n_cols;
max@0 313
max@0 314 const uword B_n_rows = BB.n_rows;
max@0 315 const uword B_n_cols = BB.n_cols;
max@0 316
max@0 317 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) && (A_n_rows == B_n_rows) && (B_n_rows == B_n_cols) )
max@0 318 {
max@0 319 gemm_emul_tinysq<false, use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
max@0 320 }
max@0 321 else
max@0 322 {
max@0 323 gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
max@0 324 }
max@0 325 }
max@0 326
max@0 327 };
max@0 328
max@0 329
max@0 330
max@0 331 //! \brief
max@0 332 //! Wrapper for ATLAS/BLAS dgemm function, using template arguments to control the arguments passed to dgemm.
max@0 333 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
max@0 334
max@0 335 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
max@0 336 class gemm
max@0 337 {
max@0 338 public:
max@0 339
max@0 340 template<typename eT>
max@0 341 inline
max@0 342 static
max@0 343 void
max@0 344 apply_blas_type( Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0) )
max@0 345 {
max@0 346 arma_extra_debug_sigprint();
max@0 347
max@0 348 if( (A.n_elem <= 48u) && (B.n_elem <= 48u) )
max@0 349 {
max@0 350 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
max@0 351 }
max@0 352 else
max@0 353 {
max@0 354 #if defined(ARMA_USE_ATLAS)
max@0 355 {
max@0 356 arma_extra_debug_print("atlas::cblas_gemm()");
max@0 357
max@0 358 atlas::cblas_gemm<eT>
max@0 359 (
max@0 360 atlas::CblasColMajor,
max@0 361 (do_trans_A) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
max@0 362 (do_trans_B) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
max@0 363 C.n_rows,
max@0 364 C.n_cols,
max@0 365 (do_trans_A) ? A.n_rows : A.n_cols,
max@0 366 (use_alpha) ? alpha : eT(1),
max@0 367 A.mem,
max@0 368 (do_trans_A) ? A.n_rows : C.n_rows,
max@0 369 B.mem,
max@0 370 (do_trans_B) ? C.n_cols : ( (do_trans_A) ? A.n_rows : A.n_cols ),
max@0 371 (use_beta) ? beta : eT(0),
max@0 372 C.memptr(),
max@0 373 C.n_rows
max@0 374 );
max@0 375 }
max@0 376 #elif defined(ARMA_USE_BLAS)
max@0 377 {
max@0 378 arma_extra_debug_print("blas::gemm()");
max@0 379
max@0 380 const char trans_A = (do_trans_A) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N';
max@0 381 const char trans_B = (do_trans_B) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N';
max@0 382
max@0 383 const blas_int m = C.n_rows;
max@0 384 const blas_int n = C.n_cols;
max@0 385 const blas_int k = (do_trans_A) ? A.n_rows : A.n_cols;
max@0 386
max@0 387 const eT local_alpha = (use_alpha) ? alpha : eT(1);
max@0 388
max@0 389 const blas_int lda = (do_trans_A) ? k : m;
max@0 390 const blas_int ldb = (do_trans_B) ? n : k;
max@0 391
max@0 392 const eT local_beta = (use_beta) ? beta : eT(0);
max@0 393
max@0 394 arma_extra_debug_print( arma_boost::format("blas::gemm(): trans_A = %c") % trans_A );
max@0 395 arma_extra_debug_print( arma_boost::format("blas::gemm(): trans_B = %c") % trans_B );
max@0 396
max@0 397 blas::gemm<eT>
max@0 398 (
max@0 399 &trans_A,
max@0 400 &trans_B,
max@0 401 &m,
max@0 402 &n,
max@0 403 &k,
max@0 404 &local_alpha,
max@0 405 A.mem,
max@0 406 &lda,
max@0 407 B.mem,
max@0 408 &ldb,
max@0 409 &local_beta,
max@0 410 C.memptr(),
max@0 411 &m
max@0 412 );
max@0 413 }
max@0 414 #else
max@0 415 {
max@0 416 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
max@0 417 }
max@0 418 #endif
max@0 419 }
max@0 420 }
max@0 421
max@0 422
max@0 423
max@0 424 //! immediate multiplication of matrices A and B, storing the result in C
max@0 425 template<typename eT>
max@0 426 inline
max@0 427 static
max@0 428 void
max@0 429 apply( Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0) )
max@0 430 {
max@0 431 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
max@0 432 }
max@0 433
max@0 434
max@0 435
max@0 436 arma_inline
max@0 437 static
max@0 438 void
max@0 439 apply
max@0 440 (
max@0 441 Mat<float>& C,
max@0 442 const Mat<float>& A,
max@0 443 const Mat<float>& B,
max@0 444 const float alpha = float(1),
max@0 445 const float beta = float(0)
max@0 446 )
max@0 447 {
max@0 448 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
max@0 449 }
max@0 450
max@0 451
max@0 452
max@0 453 arma_inline
max@0 454 static
max@0 455 void
max@0 456 apply
max@0 457 (
max@0 458 Mat<double>& C,
max@0 459 const Mat<double>& A,
max@0 460 const Mat<double>& B,
max@0 461 const double alpha = double(1),
max@0 462 const double beta = double(0)
max@0 463 )
max@0 464 {
max@0 465 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
max@0 466 }
max@0 467
max@0 468
max@0 469
max@0 470 arma_inline
max@0 471 static
max@0 472 void
max@0 473 apply
max@0 474 (
max@0 475 Mat< std::complex<float> >& C,
max@0 476 const Mat< std::complex<float> >& A,
max@0 477 const Mat< std::complex<float> >& B,
max@0 478 const std::complex<float> alpha = std::complex<float>(1),
max@0 479 const std::complex<float> beta = std::complex<float>(0)
max@0 480 )
max@0 481 {
max@0 482 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
max@0 483 }
max@0 484
max@0 485
max@0 486
max@0 487 arma_inline
max@0 488 static
max@0 489 void
max@0 490 apply
max@0 491 (
max@0 492 Mat< std::complex<double> >& C,
max@0 493 const Mat< std::complex<double> >& A,
max@0 494 const Mat< std::complex<double> >& B,
max@0 495 const std::complex<double> alpha = std::complex<double>(1),
max@0 496 const std::complex<double> beta = std::complex<double>(0)
max@0 497 )
max@0 498 {
max@0 499 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
max@0 500 }
max@0 501
max@0 502 };
max@0 503
max@0 504
max@0 505
max@0 506 //! @}