annotate armadillo-3.900.4/include/armadillo_bits/gemm_mixed.hpp @ 84:55a047986812 tip

Update library URI so as not to be document-local
author Chris Cannam
date Wed, 22 Apr 2020 14:21:57 +0100
parents 1ec0e2823891
children
rev   line source
Chris@49 1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
Chris@49 2 // Copyright (C) 2008-2011 Conrad Sanderson
Chris@49 3 //
Chris@49 4 // This Source Code Form is subject to the terms of the Mozilla Public
Chris@49 5 // License, v. 2.0. If a copy of the MPL was not distributed with this
Chris@49 6 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
Chris@49 7
Chris@49 8
Chris@49 9 //! \addtogroup gemm_mixed
Chris@49 10 //! @{
Chris@49 11
Chris@49 12
Chris@49 13
Chris@49 14 //! \brief
Chris@49 15 //! Matrix multplication where the matrices have differing element types.
Chris@49 16 //! Uses caching for speedup.
Chris@49 17 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
Chris@49 18
Chris@49 19 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
Chris@49 20 class gemm_mixed_large
Chris@49 21 {
Chris@49 22 public:
Chris@49 23
Chris@49 24 template<typename out_eT, typename in_eT1, typename in_eT2>
Chris@49 25 arma_hot
Chris@49 26 inline
Chris@49 27 static
Chris@49 28 void
Chris@49 29 apply
Chris@49 30 (
Chris@49 31 Mat<out_eT>& C,
Chris@49 32 const Mat<in_eT1>& A,
Chris@49 33 const Mat<in_eT2>& B,
Chris@49 34 const out_eT alpha = out_eT(1),
Chris@49 35 const out_eT beta = out_eT(0)
Chris@49 36 )
Chris@49 37 {
Chris@49 38 arma_extra_debug_sigprint();
Chris@49 39
Chris@49 40 const uword A_n_rows = A.n_rows;
Chris@49 41 const uword A_n_cols = A.n_cols;
Chris@49 42
Chris@49 43 const uword B_n_rows = B.n_rows;
Chris@49 44 const uword B_n_cols = B.n_cols;
Chris@49 45
Chris@49 46 if( (do_trans_A == false) && (do_trans_B == false) )
Chris@49 47 {
Chris@49 48 podarray<in_eT1> tmp(A_n_cols);
Chris@49 49 in_eT1* A_rowdata = tmp.memptr();
Chris@49 50
Chris@49 51 for(uword row_A=0; row_A < A_n_rows; ++row_A)
Chris@49 52 {
Chris@49 53 tmp.copy_row(A, row_A);
Chris@49 54
Chris@49 55 for(uword col_B=0; col_B < B_n_cols; ++col_B)
Chris@49 56 {
Chris@49 57 const in_eT2* B_coldata = B.colptr(col_B);
Chris@49 58
Chris@49 59 out_eT acc = out_eT(0);
Chris@49 60 for(uword i=0; i < B_n_rows; ++i)
Chris@49 61 {
Chris@49 62 acc += upgrade_val<in_eT1,in_eT2>::apply(A_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
Chris@49 63 }
Chris@49 64
Chris@49 65 if( (use_alpha == false) && (use_beta == false) )
Chris@49 66 {
Chris@49 67 C.at(row_A,col_B) = acc;
Chris@49 68 }
Chris@49 69 else
Chris@49 70 if( (use_alpha == true) && (use_beta == false) )
Chris@49 71 {
Chris@49 72 C.at(row_A,col_B) = alpha * acc;
Chris@49 73 }
Chris@49 74 else
Chris@49 75 if( (use_alpha == false) && (use_beta == true) )
Chris@49 76 {
Chris@49 77 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
Chris@49 78 }
Chris@49 79 else
Chris@49 80 if( (use_alpha == true) && (use_beta == true) )
Chris@49 81 {
Chris@49 82 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
Chris@49 83 }
Chris@49 84
Chris@49 85 }
Chris@49 86 }
Chris@49 87 }
Chris@49 88 else
Chris@49 89 if( (do_trans_A == true) && (do_trans_B == false) )
Chris@49 90 {
Chris@49 91 for(uword col_A=0; col_A < A_n_cols; ++col_A)
Chris@49 92 {
Chris@49 93 // col_A is interpreted as row_A when storing the results in matrix C
Chris@49 94
Chris@49 95 const in_eT1* A_coldata = A.colptr(col_A);
Chris@49 96
Chris@49 97 for(uword col_B=0; col_B < B_n_cols; ++col_B)
Chris@49 98 {
Chris@49 99 const in_eT2* B_coldata = B.colptr(col_B);
Chris@49 100
Chris@49 101 out_eT acc = out_eT(0);
Chris@49 102 for(uword i=0; i < B_n_rows; ++i)
Chris@49 103 {
Chris@49 104 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
Chris@49 105 }
Chris@49 106
Chris@49 107 if( (use_alpha == false) && (use_beta == false) )
Chris@49 108 {
Chris@49 109 C.at(col_A,col_B) = acc;
Chris@49 110 }
Chris@49 111 else
Chris@49 112 if( (use_alpha == true) && (use_beta == false) )
Chris@49 113 {
Chris@49 114 C.at(col_A,col_B) = alpha * acc;
Chris@49 115 }
Chris@49 116 else
Chris@49 117 if( (use_alpha == false) && (use_beta == true) )
Chris@49 118 {
Chris@49 119 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
Chris@49 120 }
Chris@49 121 else
Chris@49 122 if( (use_alpha == true) && (use_beta == true) )
Chris@49 123 {
Chris@49 124 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
Chris@49 125 }
Chris@49 126
Chris@49 127 }
Chris@49 128 }
Chris@49 129 }
Chris@49 130 else
Chris@49 131 if( (do_trans_A == false) && (do_trans_B == true) )
Chris@49 132 {
Chris@49 133 Mat<in_eT2> B_tmp;
Chris@49 134
Chris@49 135 op_strans::apply_noalias(B_tmp, B);
Chris@49 136
Chris@49 137 gemm_mixed_large<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
Chris@49 138 }
Chris@49 139 else
Chris@49 140 if( (do_trans_A == true) && (do_trans_B == true) )
Chris@49 141 {
Chris@49 142 // mat B_tmp = trans(B);
Chris@49 143 // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
Chris@49 144
Chris@49 145
Chris@49 146 // By using the trans(A)*trans(B) = trans(B*A) equivalency,
Chris@49 147 // transpose operations are not needed
Chris@49 148
Chris@49 149 podarray<in_eT2> tmp(B_n_cols);
Chris@49 150 in_eT2* B_rowdata = tmp.memptr();
Chris@49 151
Chris@49 152 for(uword row_B=0; row_B < B_n_rows; ++row_B)
Chris@49 153 {
Chris@49 154 tmp.copy_row(B, row_B);
Chris@49 155
Chris@49 156 for(uword col_A=0; col_A < A_n_cols; ++col_A)
Chris@49 157 {
Chris@49 158 const in_eT1* A_coldata = A.colptr(col_A);
Chris@49 159
Chris@49 160 out_eT acc = out_eT(0);
Chris@49 161 for(uword i=0; i < A_n_rows; ++i)
Chris@49 162 {
Chris@49 163 acc += upgrade_val<in_eT1,in_eT2>::apply(B_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
Chris@49 164 }
Chris@49 165
Chris@49 166 if( (use_alpha == false) && (use_beta == false) )
Chris@49 167 {
Chris@49 168 C.at(col_A,row_B) = acc;
Chris@49 169 }
Chris@49 170 else
Chris@49 171 if( (use_alpha == true) && (use_beta == false) )
Chris@49 172 {
Chris@49 173 C.at(col_A,row_B) = alpha * acc;
Chris@49 174 }
Chris@49 175 else
Chris@49 176 if( (use_alpha == false) && (use_beta == true) )
Chris@49 177 {
Chris@49 178 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
Chris@49 179 }
Chris@49 180 else
Chris@49 181 if( (use_alpha == true) && (use_beta == true) )
Chris@49 182 {
Chris@49 183 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
Chris@49 184 }
Chris@49 185
Chris@49 186 }
Chris@49 187 }
Chris@49 188
Chris@49 189 }
Chris@49 190 }
Chris@49 191
Chris@49 192 };
Chris@49 193
Chris@49 194
Chris@49 195
Chris@49 196 //! Matrix multplication where the matrices have different element types.
Chris@49 197 //! Simple version (no caching).
Chris@49 198 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
Chris@49 199 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
Chris@49 200 class gemm_mixed_small
Chris@49 201 {
Chris@49 202 public:
Chris@49 203
Chris@49 204 template<typename out_eT, typename in_eT1, typename in_eT2>
Chris@49 205 arma_hot
Chris@49 206 inline
Chris@49 207 static
Chris@49 208 void
Chris@49 209 apply
Chris@49 210 (
Chris@49 211 Mat<out_eT>& C,
Chris@49 212 const Mat<in_eT1>& A,
Chris@49 213 const Mat<in_eT2>& B,
Chris@49 214 const out_eT alpha = out_eT(1),
Chris@49 215 const out_eT beta = out_eT(0)
Chris@49 216 )
Chris@49 217 {
Chris@49 218 arma_extra_debug_sigprint();
Chris@49 219
Chris@49 220 const uword A_n_rows = A.n_rows;
Chris@49 221 const uword A_n_cols = A.n_cols;
Chris@49 222
Chris@49 223 const uword B_n_rows = B.n_rows;
Chris@49 224 const uword B_n_cols = B.n_cols;
Chris@49 225
Chris@49 226 if( (do_trans_A == false) && (do_trans_B == false) )
Chris@49 227 {
Chris@49 228 for(uword row_A = 0; row_A < A_n_rows; ++row_A)
Chris@49 229 {
Chris@49 230 for(uword col_B = 0; col_B < B_n_cols; ++col_B)
Chris@49 231 {
Chris@49 232 const in_eT2* B_coldata = B.colptr(col_B);
Chris@49 233
Chris@49 234 out_eT acc = out_eT(0);
Chris@49 235 for(uword i = 0; i < B_n_rows; ++i)
Chris@49 236 {
Chris@49 237 const out_eT val1 = upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i));
Chris@49 238 const out_eT val2 = upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
Chris@49 239 acc += val1 * val2;
Chris@49 240 //acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
Chris@49 241 }
Chris@49 242
Chris@49 243 if( (use_alpha == false) && (use_beta == false) )
Chris@49 244 {
Chris@49 245 C.at(row_A,col_B) = acc;
Chris@49 246 }
Chris@49 247 else
Chris@49 248 if( (use_alpha == true) && (use_beta == false) )
Chris@49 249 {
Chris@49 250 C.at(row_A,col_B) = alpha * acc;
Chris@49 251 }
Chris@49 252 else
Chris@49 253 if( (use_alpha == false) && (use_beta == true) )
Chris@49 254 {
Chris@49 255 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
Chris@49 256 }
Chris@49 257 else
Chris@49 258 if( (use_alpha == true) && (use_beta == true) )
Chris@49 259 {
Chris@49 260 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
Chris@49 261 }
Chris@49 262 }
Chris@49 263 }
Chris@49 264 }
Chris@49 265 else
Chris@49 266 if( (do_trans_A == true) && (do_trans_B == false) )
Chris@49 267 {
Chris@49 268 for(uword col_A=0; col_A < A_n_cols; ++col_A)
Chris@49 269 {
Chris@49 270 // col_A is interpreted as row_A when storing the results in matrix C
Chris@49 271
Chris@49 272 const in_eT1* A_coldata = A.colptr(col_A);
Chris@49 273
Chris@49 274 for(uword col_B=0; col_B < B_n_cols; ++col_B)
Chris@49 275 {
Chris@49 276 const in_eT2* B_coldata = B.colptr(col_B);
Chris@49 277
Chris@49 278 out_eT acc = out_eT(0);
Chris@49 279 for(uword i=0; i < B_n_rows; ++i)
Chris@49 280 {
Chris@49 281 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
Chris@49 282 }
Chris@49 283
Chris@49 284 if( (use_alpha == false) && (use_beta == false) )
Chris@49 285 {
Chris@49 286 C.at(col_A,col_B) = acc;
Chris@49 287 }
Chris@49 288 else
Chris@49 289 if( (use_alpha == true) && (use_beta == false) )
Chris@49 290 {
Chris@49 291 C.at(col_A,col_B) = alpha * acc;
Chris@49 292 }
Chris@49 293 else
Chris@49 294 if( (use_alpha == false) && (use_beta == true) )
Chris@49 295 {
Chris@49 296 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
Chris@49 297 }
Chris@49 298 else
Chris@49 299 if( (use_alpha == true) && (use_beta == true) )
Chris@49 300 {
Chris@49 301 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
Chris@49 302 }
Chris@49 303
Chris@49 304 }
Chris@49 305 }
Chris@49 306 }
Chris@49 307 else
Chris@49 308 if( (do_trans_A == false) && (do_trans_B == true) )
Chris@49 309 {
Chris@49 310 for(uword row_A = 0; row_A < A_n_rows; ++row_A)
Chris@49 311 {
Chris@49 312 for(uword row_B = 0; row_B < B_n_rows; ++row_B)
Chris@49 313 {
Chris@49 314 out_eT acc = out_eT(0);
Chris@49 315 for(uword i = 0; i < B_n_cols; ++i)
Chris@49 316 {
Chris@49 317 acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i));
Chris@49 318 }
Chris@49 319
Chris@49 320 if( (use_alpha == false) && (use_beta == false) )
Chris@49 321 {
Chris@49 322 C.at(row_A,row_B) = acc;
Chris@49 323 }
Chris@49 324 else
Chris@49 325 if( (use_alpha == true) && (use_beta == false) )
Chris@49 326 {
Chris@49 327 C.at(row_A,row_B) = alpha * acc;
Chris@49 328 }
Chris@49 329 else
Chris@49 330 if( (use_alpha == false) && (use_beta == true) )
Chris@49 331 {
Chris@49 332 C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B);
Chris@49 333 }
Chris@49 334 else
Chris@49 335 if( (use_alpha == true) && (use_beta == true) )
Chris@49 336 {
Chris@49 337 C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B);
Chris@49 338 }
Chris@49 339 }
Chris@49 340 }
Chris@49 341 }
Chris@49 342 else
Chris@49 343 if( (do_trans_A == true) && (do_trans_B == true) )
Chris@49 344 {
Chris@49 345 for(uword row_B=0; row_B < B_n_rows; ++row_B)
Chris@49 346 {
Chris@49 347
Chris@49 348 for(uword col_A=0; col_A < A_n_cols; ++col_A)
Chris@49 349 {
Chris@49 350 const in_eT1* A_coldata = A.colptr(col_A);
Chris@49 351
Chris@49 352 out_eT acc = out_eT(0);
Chris@49 353 for(uword i=0; i < A_n_rows; ++i)
Chris@49 354 {
Chris@49 355 acc += upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
Chris@49 356 }
Chris@49 357
Chris@49 358 if( (use_alpha == false) && (use_beta == false) )
Chris@49 359 {
Chris@49 360 C.at(col_A,row_B) = acc;
Chris@49 361 }
Chris@49 362 else
Chris@49 363 if( (use_alpha == true) && (use_beta == false) )
Chris@49 364 {
Chris@49 365 C.at(col_A,row_B) = alpha * acc;
Chris@49 366 }
Chris@49 367 else
Chris@49 368 if( (use_alpha == false) && (use_beta == true) )
Chris@49 369 {
Chris@49 370 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
Chris@49 371 }
Chris@49 372 else
Chris@49 373 if( (use_alpha == true) && (use_beta == true) )
Chris@49 374 {
Chris@49 375 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
Chris@49 376 }
Chris@49 377
Chris@49 378 }
Chris@49 379 }
Chris@49 380
Chris@49 381 }
Chris@49 382 }
Chris@49 383
Chris@49 384 };
Chris@49 385
Chris@49 386
Chris@49 387
Chris@49 388
Chris@49 389
Chris@49 390 //! \brief
Chris@49 391 //! Matrix multplication where the matrices have differing element types.
Chris@49 392
Chris@49 393 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
Chris@49 394 class gemm_mixed
Chris@49 395 {
Chris@49 396 public:
Chris@49 397
Chris@49 398 //! immediate multiplication of matrices A and B, storing the result in C
Chris@49 399 template<typename out_eT, typename in_eT1, typename in_eT2>
Chris@49 400 inline
Chris@49 401 static
Chris@49 402 void
Chris@49 403 apply
Chris@49 404 (
Chris@49 405 Mat<out_eT>& C,
Chris@49 406 const Mat<in_eT1>& A,
Chris@49 407 const Mat<in_eT2>& B,
Chris@49 408 const out_eT alpha = out_eT(1),
Chris@49 409 const out_eT beta = out_eT(0)
Chris@49 410 )
Chris@49 411 {
Chris@49 412 arma_extra_debug_sigprint();
Chris@49 413
Chris@49 414 Mat<in_eT1> tmp_A;
Chris@49 415 Mat<in_eT2> tmp_B;
Chris@49 416
Chris@49 417 const bool predo_trans_A = ( (do_trans_A == true) && (is_complex<in_eT1>::value == true) );
Chris@49 418 const bool predo_trans_B = ( (do_trans_B == true) && (is_complex<in_eT2>::value == true) );
Chris@49 419
Chris@49 420 if(do_trans_A)
Chris@49 421 {
Chris@49 422 op_htrans::apply_noalias(tmp_A, A);
Chris@49 423 }
Chris@49 424
Chris@49 425 if(do_trans_B)
Chris@49 426 {
Chris@49 427 op_htrans::apply_noalias(tmp_B, B);
Chris@49 428 }
Chris@49 429
Chris@49 430 const Mat<in_eT1>& AA = (predo_trans_A == false) ? A : tmp_A;
Chris@49 431 const Mat<in_eT2>& BB = (predo_trans_B == false) ? B : tmp_B;
Chris@49 432
Chris@49 433 if( (AA.n_elem <= 64u) && (BB.n_elem <= 64u) )
Chris@49 434 {
Chris@49 435 gemm_mixed_small<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
Chris@49 436 }
Chris@49 437 else
Chris@49 438 {
Chris@49 439 gemm_mixed_large<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
Chris@49 440 }
Chris@49 441 }
Chris@49 442
Chris@49 443
Chris@49 444 };
Chris@49 445
Chris@49 446
Chris@49 447
Chris@49 448 //! @}