annotate armadillo-2.4.4/include/armadillo_bits/gemm_mixed.hpp @ 5:79b343f3e4b8

In this version the problem of letters assigned to each segment has been solved.
author maxzanoni76 <max.zanoni@eecs.qmul.ac.uk>
date Wed, 11 Apr 2012 13:48:13 +0100
parents 8b6102e2a9b0
children
rev   line source
max@0 1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
max@0 2 // Copyright (C) 2008-2011 Conrad Sanderson
max@0 3 //
max@0 4 // This file is part of the Armadillo C++ library.
max@0 5 // It is provided without any warranty of fitness
max@0 6 // for any purpose. You can redistribute this file
max@0 7 // and/or modify it under the terms of the GNU
max@0 8 // Lesser General Public License (LGPL) as published
max@0 9 // by the Free Software Foundation, either version 3
max@0 10 // of the License or (at your option) any later version.
max@0 11 // (see http://www.opensource.org/licenses for more info)
max@0 12
max@0 13
max@0 14 //! \addtogroup gemm_mixed
max@0 15 //! @{
max@0 16
max@0 17
max@0 18
max@0 19 //! \brief
max@0 20 //! Matrix multiplication where the matrices have differing element types.
max@0 21 //! Uses caching for speedup.
max@0 22 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
max@0 23
max@0 24 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
max@0 25 class gemm_mixed_large
max@0 26 {
max@0 27 public:
max@0 28 //! Stores op(A)*op(B) into C, where op() is a transpose selected by do_trans_A / do_trans_B; the product is scaled by alpha only if use_alpha is true, and beta*C is added only if use_beta is true.
max@0 29 template<typename out_eT, typename in_eT1, typename in_eT2>
max@0 30 arma_hot
max@0 31 inline
max@0 32 static
max@0 33 void
max@0 34 apply
max@0 35 (
max@0 36 Mat<out_eT>& C, //!< result matrix; written element by element and never resized in this function
max@0 37 const Mat<in_eT1>& A, //!< left operand
max@0 38 const Mat<in_eT2>& B, //!< right operand
max@0 39 const out_eT alpha = out_eT(1), //!< scale factor for the product; read only when use_alpha is true
max@0 40 const out_eT beta = out_eT(0) //!< scale factor for the previous contents of C; read only when use_beta is true
max@0 41 )
max@0 42 {
max@0 43 arma_extra_debug_sigprint();
max@0 44 // take local copies of the operand dimensions used as loop bounds below
max@0 45 const uword A_n_rows = A.n_rows;
max@0 46 const uword A_n_cols = A.n_cols;
max@0 47
max@0 48 const uword B_n_rows = B.n_rows;
max@0 49 const uword B_n_cols = B.n_cols;
max@0 50 // the transpose flags are template constants, so exactly one of the four branches below applies per instantiation
max@0 51 if( (do_trans_A == false) && (do_trans_B == false) ) // case: C = A*B
max@0 52 {
max@0 53 podarray<in_eT1> tmp(A_n_cols); // scratch buffer sized to hold one row of A
max@0 54 in_eT1* A_rowdata = tmp.memptr();
max@0 55
max@0 56 for(uword row_A=0; row_A < A_n_rows; ++row_A)
max@0 57 {
max@0 58 tmp.copy_row(A, row_A); // copy the current row of A once, then reuse it for every column of B
max@0 59
max@0 60 for(uword col_B=0; col_B < B_n_cols; ++col_B)
max@0 61 {
max@0 62 const in_eT2* B_coldata = B.colptr(col_B);
max@0 63 // dot product of the cached row of A with column col_B of B; the bound is B_n_rows, and no check against A_n_cols is made here
max@0 64 out_eT acc = out_eT(0);
max@0 65 for(uword i=0; i < B_n_rows; ++i)
max@0 66 {
max@0 67 acc += upgrade_val<in_eT1,in_eT2>::apply(A_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); // upgrade_val is defined elsewhere; presumably it promotes both operands to a common type
max@0 68 }
max@0 69 // use_alpha / use_beta are template constants: only one of the four stores below is taken
max@0 70 if( (use_alpha == false) && (use_beta == false) )
max@0 71 {
max@0 72 C.at(row_A,col_B) = acc;
max@0 73 }
max@0 74 else
max@0 75 if( (use_alpha == true) && (use_beta == false) )
max@0 76 {
max@0 77 C.at(row_A,col_B) = alpha * acc;
max@0 78 }
max@0 79 else
max@0 80 if( (use_alpha == false) && (use_beta == true) )
max@0 81 {
max@0 82 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
max@0 83 }
max@0 84 else
max@0 85 if( (use_alpha == true) && (use_beta == true) )
max@0 86 {
max@0 87 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
max@0 88 }
max@0 89
max@0 90 }
max@0 91 }
max@0 92 }
max@0 93 else
max@0 94 if( (do_trans_A == true) && (do_trans_B == false) ) // case: C = trans(A)*B
max@0 95 {
max@0 96 for(uword col_A=0; col_A < A_n_cols; ++col_A)
max@0 97 {
max@0 98 // col_A is interpreted as row_A when storing the results in matrix C
max@0 99 // no scratch buffer in this case: columns of A are read directly through colptr()
max@0 100 const in_eT1* A_coldata = A.colptr(col_A);
max@0 101
max@0 102 for(uword col_B=0; col_B < B_n_cols; ++col_B)
max@0 103 {
max@0 104 const in_eT2* B_coldata = B.colptr(col_B);
max@0 105 // dot product of column col_A of A with column col_B of B
max@0 106 out_eT acc = out_eT(0);
max@0 107 for(uword i=0; i < B_n_rows; ++i)
max@0 108 {
max@0 109 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
max@0 110 }
max@0 111 // same alpha/beta store selection as in the first case
max@0 112 if( (use_alpha == false) && (use_beta == false) )
max@0 113 {
max@0 114 C.at(col_A,col_B) = acc;
max@0 115 }
max@0 116 else
max@0 117 if( (use_alpha == true) && (use_beta == false) )
max@0 118 {
max@0 119 C.at(col_A,col_B) = alpha * acc;
max@0 120 }
max@0 121 else
max@0 122 if( (use_alpha == false) && (use_beta == true) )
max@0 123 {
max@0 124 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
max@0 125 }
max@0 126 else
max@0 127 if( (use_alpha == true) && (use_beta == true) )
max@0 128 {
max@0 129 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
max@0 130 }
max@0 131
max@0 132 }
max@0 133 }
max@0 134 }
max@0 135 else
max@0 136 if( (do_trans_A == false) && (do_trans_B == true) ) // case: C = A*trans(B)
max@0 137 {
max@0 138 Mat<in_eT2> B_tmp;
max@0 139 // build an explicit transposed copy of B, then reuse the no-transpose case above
max@0 140 op_strans::apply_noalias(B_tmp, B);
max@0 141
max@0 142 gemm_mixed_large<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
max@0 143 }
max@0 144 else
max@0 145 if( (do_trans_A == true) && (do_trans_B == true) ) // case: C = trans(A)*trans(B)
max@0 146 {
max@0 147 // mat B_tmp = trans(B);
max@0 148 // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
max@0 149
max@0 150
max@0 151 // By using the trans(A)*trans(B) = trans(B*A) equivalency,
max@0 152 // transpose operations are not needed
max@0 153
max@0 154 podarray<in_eT2> tmp(B_n_cols); // scratch buffer sized to hold one row of B
max@0 155 in_eT2* B_rowdata = tmp.memptr();
max@0 156
max@0 157 for(uword row_B=0; row_B < B_n_rows; ++row_B)
max@0 158 {
max@0 159 tmp.copy_row(B, row_B); // copy the current row of B once, then reuse it for every column of A
max@0 160
max@0 161 for(uword col_A=0; col_A < A_n_cols; ++col_A)
max@0 162 {
max@0 163 const in_eT1* A_coldata = A.colptr(col_A);
max@0 164 // element (row_B,col_A) of B*A; the bound is A_n_rows while the buffer holds B_n_cols values, and their equality is not checked here
max@0 165 out_eT acc = out_eT(0);
max@0 166 for(uword i=0; i < A_n_rows; ++i)
max@0 167 {
max@0 168 acc += upgrade_val<in_eT1,in_eT2>::apply(B_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
max@0 169 }
max@0 170 // stored at (col_A,row_B), i.e. with swapped indices, which yields trans(B*A)
max@0 171 if( (use_alpha == false) && (use_beta == false) )
max@0 172 {
max@0 173 C.at(col_A,row_B) = acc;
max@0 174 }
max@0 175 else
max@0 176 if( (use_alpha == true) && (use_beta == false) )
max@0 177 {
max@0 178 C.at(col_A,row_B) = alpha * acc;
max@0 179 }
max@0 180 else
max@0 181 if( (use_alpha == false) && (use_beta == true) )
max@0 182 {
max@0 183 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
max@0 184 }
max@0 185 else
max@0 186 if( (use_alpha == true) && (use_beta == true) )
max@0 187 {
max@0 188 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
max@0 189 }
max@0 190
max@0 191 }
max@0 192 }
max@0 193
max@0 194 }
max@0 195 }
max@0 196
max@0 197 };
max@0 198
max@0 199
max@0 200
max@0 201 //! Matrix multiplication where the matrices have different element types.
max@0 202 //! Simple version (no caching).
max@0 203 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
max@0 204 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
max@0 205 class gemm_mixed_small
max@0 206 {
max@0 207 public:
max@0 208 //! Stores op(A)*op(B) into C, where op() is a transpose selected by do_trans_A / do_trans_B; the product is scaled by alpha only if use_alpha is true, and beta*C is added only if use_beta is true. No scratch buffers or temporary matrices are created.
max@0 209 template<typename out_eT, typename in_eT1, typename in_eT2>
max@0 210 arma_hot
max@0 211 inline
max@0 212 static
max@0 213 void
max@0 214 apply
max@0 215 (
max@0 216 Mat<out_eT>& C, //!< result matrix; written element by element and never resized in this function
max@0 217 const Mat<in_eT1>& A, //!< left operand
max@0 218 const Mat<in_eT2>& B, //!< right operand
max@0 219 const out_eT alpha = out_eT(1), //!< scale factor for the product; read only when use_alpha is true
max@0 220 const out_eT beta = out_eT(0) //!< scale factor for the previous contents of C; read only when use_beta is true
max@0 221 )
max@0 222 {
max@0 223 arma_extra_debug_sigprint();
max@0 224 // take local copies of the operand dimensions used as loop bounds below
max@0 225 const uword A_n_rows = A.n_rows;
max@0 226 const uword A_n_cols = A.n_cols;
max@0 227
max@0 228 const uword B_n_rows = B.n_rows;
max@0 229 const uword B_n_cols = B.n_cols;
max@0 230 // the transpose flags are template constants, so exactly one of the four branches below applies per instantiation
max@0 231 if( (do_trans_A == false) && (do_trans_B == false) ) // case: C = A*B
max@0 232 {
max@0 233 for(uword row_A = 0; row_A < A_n_rows; ++row_A)
max@0 234 {
max@0 235 for(uword col_B = 0; col_B < B_n_cols; ++col_B)
max@0 236 {
max@0 237 const in_eT2* B_coldata = B.colptr(col_B);
max@0 238 // dot product of row row_A of A (read via at()) with column col_B of B
max@0 239 out_eT acc = out_eT(0);
max@0 240 for(uword i = 0; i < B_n_rows; ++i)
max@0 241 {
max@0 242 const out_eT val1 = upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)); // upgrade_val is defined elsewhere; presumably it promotes both operands to a common type
max@0 243 const out_eT val2 = upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
max@0 244 acc += val1 * val2;
max@0 245 //acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
max@0 246 }
max@0 247 // use_alpha / use_beta are template constants: only one of the four stores below is taken
max@0 248 if( (use_alpha == false) && (use_beta == false) )
max@0 249 {
max@0 250 C.at(row_A,col_B) = acc;
max@0 251 }
max@0 252 else
max@0 253 if( (use_alpha == true) && (use_beta == false) )
max@0 254 {
max@0 255 C.at(row_A,col_B) = alpha * acc;
max@0 256 }
max@0 257 else
max@0 258 if( (use_alpha == false) && (use_beta == true) )
max@0 259 {
max@0 260 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
max@0 261 }
max@0 262 else
max@0 263 if( (use_alpha == true) && (use_beta == true) )
max@0 264 {
max@0 265 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
max@0 266 }
max@0 267 }
max@0 268 }
max@0 269 }
max@0 270 else
max@0 271 if( (do_trans_A == true) && (do_trans_B == false) ) // case: C = trans(A)*B
max@0 272 {
max@0 273 for(uword col_A=0; col_A < A_n_cols; ++col_A)
max@0 274 {
max@0 275 // col_A is interpreted as row_A when storing the results in matrix C
max@0 276
max@0 277 const in_eT1* A_coldata = A.colptr(col_A);
max@0 278
max@0 279 for(uword col_B=0; col_B < B_n_cols; ++col_B)
max@0 280 {
max@0 281 const in_eT2* B_coldata = B.colptr(col_B);
max@0 282 // dot product of column col_A of A with column col_B of B
max@0 283 out_eT acc = out_eT(0);
max@0 284 for(uword i=0; i < B_n_rows; ++i)
max@0 285 {
max@0 286 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
max@0 287 }
max@0 288 // same alpha/beta store selection as in the first case
max@0 289 if( (use_alpha == false) && (use_beta == false) )
max@0 290 {
max@0 291 C.at(col_A,col_B) = acc;
max@0 292 }
max@0 293 else
max@0 294 if( (use_alpha == true) && (use_beta == false) )
max@0 295 {
max@0 296 C.at(col_A,col_B) = alpha * acc;
max@0 297 }
max@0 298 else
max@0 299 if( (use_alpha == false) && (use_beta == true) )
max@0 300 {
max@0 301 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
max@0 302 }
max@0 303 else
max@0 304 if( (use_alpha == true) && (use_beta == true) )
max@0 305 {
max@0 306 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
max@0 307 }
max@0 308
max@0 309 }
max@0 310 }
max@0 311 }
max@0 312 else
max@0 313 if( (do_trans_A == false) && (do_trans_B == true) ) // case: C = A*trans(B)
max@0 314 {
max@0 315 for(uword row_A = 0; row_A < A_n_rows; ++row_A)
max@0 316 {
max@0 317 for(uword row_B = 0; row_B < B_n_rows; ++row_B)
max@0 318 {
max@0 319 out_eT acc = out_eT(0); // dot product of row row_A of A with row row_B of B, both read via at()
max@0 320 for(uword i = 0; i < B_n_cols; ++i)
max@0 321 {
max@0 322 acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i));
max@0 323 }
max@0 324 // row_B acts as the column index of C in this case
max@0 325 if( (use_alpha == false) && (use_beta == false) )
max@0 326 {
max@0 327 C.at(row_A,row_B) = acc;
max@0 328 }
max@0 329 else
max@0 330 if( (use_alpha == true) && (use_beta == false) )
max@0 331 {
max@0 332 C.at(row_A,row_B) = alpha * acc;
max@0 333 }
max@0 334 else
max@0 335 if( (use_alpha == false) && (use_beta == true) )
max@0 336 {
max@0 337 C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B);
max@0 338 }
max@0 339 else
max@0 340 if( (use_alpha == true) && (use_beta == true) )
max@0 341 {
max@0 342 C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B);
max@0 343 }
max@0 344 }
max@0 345 }
max@0 346 }
max@0 347 else
max@0 348 if( (do_trans_A == true) && (do_trans_B == true) ) // case: C = trans(A)*trans(B)
max@0 349 {
max@0 350 for(uword row_B=0; row_B < B_n_rows; ++row_B)
max@0 351 {
max@0 352 // each element (row_B,col_A) of B*A is computed and stored at (col_A,row_B) of C, which gives trans(B*A)
max@0 353 for(uword col_A=0; col_A < A_n_cols; ++col_A)
max@0 354 {
max@0 355 const in_eT1* A_coldata = A.colptr(col_A);
max@0 356
max@0 357 out_eT acc = out_eT(0);
max@0 358 for(uword i=0; i < A_n_rows; ++i)
max@0 359 {
max@0 360 acc += upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
max@0 361 }
max@0 362 // same alpha/beta store selection as in the first case
max@0 363 if( (use_alpha == false) && (use_beta == false) )
max@0 364 {
max@0 365 C.at(col_A,row_B) = acc;
max@0 366 }
max@0 367 else
max@0 368 if( (use_alpha == true) && (use_beta == false) )
max@0 369 {
max@0 370 C.at(col_A,row_B) = alpha * acc;
max@0 371 }
max@0 372 else
max@0 373 if( (use_alpha == false) && (use_beta == true) )
max@0 374 {
max@0 375 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
max@0 376 }
max@0 377 else
max@0 378 if( (use_alpha == true) && (use_beta == true) )
max@0 379 {
max@0 380 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
max@0 381 }
max@0 382
max@0 383 }
max@0 384 }
max@0 385
max@0 386 }
max@0 387 }
max@0 388
max@0 389 };
max@0 390
max@0 391
max@0 392
max@0 393
max@0 394
max@0 395 //! \brief
max@0 396 //! Matrix multiplication where the matrices have differing element types.
max@0 397
max@0 398 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
max@0 399 class gemm_mixed
max@0 400 {
max@0 401 public:
max@0 402 //! Front end for mixed-type multiplication: pre-transposes complex operands with op_htrans where a transpose is requested, then forwards to gemm_mixed_small or gemm_mixed_large.
max@0 403 //! immediate multiplication of matrices A and B, storing the result in C
max@0 404 template<typename out_eT, typename in_eT1, typename in_eT2>
max@0 405 inline
max@0 406 static
max@0 407 void
max@0 408 apply
max@0 409 (
max@0 410 Mat<out_eT>& C, //!< result matrix; passed through to the selected kernel and not resized in this function
max@0 411 const Mat<in_eT1>& A, //!< left operand
max@0 412 const Mat<in_eT2>& B, //!< right operand
max@0 413 const out_eT alpha = out_eT(1), //!< scale factor for the product; forwarded to the kernel
max@0 414 const out_eT beta = out_eT(0) //!< scale factor for the previous contents of C; forwarded to the kernel
max@0 415 )
max@0 416 {
max@0 417 arma_extra_debug_sigprint();
max@0 418 // temporaries that receive a transposed copy; they stay empty unless the matching predo_trans flag is true
max@0 419 Mat<in_eT1> tmp_A;
max@0 420 Mat<in_eT2> tmp_B;
max@0 421 // a transpose is done up front only for complex element types; otherwise the request is passed on to the kernel
max@0 422 const bool predo_trans_A = ( (do_trans_A == true) && (is_complex<in_eT1>::value == true) );
max@0 423 const bool predo_trans_B = ( (do_trans_B == true) && (is_complex<in_eT2>::value == true) );
max@0 424 // build the transposed copy only when it is selected below; testing do_trans_A here would compute a copy that is discarded for non-complex types
max@0 425 if(predo_trans_A)
max@0 426 {
max@0 427 op_htrans::apply_noalias(tmp_A, A);
max@0 428 }
max@0 429 // same guard for B
max@0 430 if(predo_trans_B)
max@0 431 {
max@0 432 op_htrans::apply_noalias(tmp_B, B);
max@0 433 }
max@0 434 // choose, per operand, between the caller's matrix and its pre-transposed copy
max@0 435 const Mat<in_eT1>& AA = (predo_trans_A == false) ? A : tmp_A;
max@0 436 const Mat<in_eT2>& BB = (predo_trans_B == false) ? B : tmp_B;
max@0 437 // operands with at most 64 elements each use the simple kernel, all others the large one; a pre-transposed operand is passed with its transpose flag cleared
max@0 438 if( (AA.n_elem <= 64u) && (BB.n_elem <= 64u) )
max@0 439 {
max@0 440 gemm_mixed_small<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
max@0 441 }
max@0 442 else
max@0 443 {
max@0 444 gemm_mixed_large<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
max@0 445 }
max@0 446 }
max@0 447
max@0 448
max@0 449 };
max@0 450
max@0 451
max@0 452
max@0 453 //! @}