annotate armadillo-2.4.4/include/armadillo_bits/gemv.hpp @ 18:8d046a9d36aa slimline

Back out rev 13:ac07c60aa798. Like an idiot, I committed a whole pile of unrelated changes in the guise of a single typo fix. Will re-commit in stages
author Chris Cannam
date Thu, 10 May 2012 10:45:44 +0100
parents 8b6102e2a9b0
children
rev   line source
max@0 1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
max@0 2 // Copyright (C) 2008-2011 Conrad Sanderson
max@0 3 //
max@0 4 // This file is part of the Armadillo C++ library.
max@0 5 // It is provided without any warranty of fitness
max@0 6 // for any purpose. You can redistribute this file
max@0 7 // and/or modify it under the terms of the GNU
max@0 8 // Lesser General Public License (LGPL) as published
max@0 9 // by the Free Software Foundation, either version 3
max@0 10 // of the License or (at your option) any later version.
max@0 11 // (see http://www.opensource.org/licenses for more info)
max@0 12
max@0 13
max@0 14 //! \addtogroup gemv
max@0 15 //! @{
max@0 16
max@0 17
max@0 18
max@0 19 //! for tiny square matrices, size <= 4x4
max@0 20 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
max@0 21 class gemv_emul_tinysq
max@0 22 {
max@0 23 public:
max@0 24
max@0 25
max@0 26 template<const uword row, const uword col>
max@0 27 struct pos
max@0 28 {
max@0 29 static const uword n2 = (do_trans_A == false) ? (row + col*2) : (col + row*2);
max@0 30 static const uword n3 = (do_trans_A == false) ? (row + col*3) : (col + row*3);
max@0 31 static const uword n4 = (do_trans_A == false) ? (row + col*4) : (col + row*4);
max@0 32 };
max@0 33
max@0 34
max@0 35
max@0 36 template<typename eT, const uword i>
max@0 37 arma_hot
max@0 38 arma_inline
max@0 39 static
max@0 40 void
max@0 41 assign(eT* y, const eT acc, const eT alpha, const eT beta)
max@0 42 {
max@0 43 if(use_beta == false)
max@0 44 {
max@0 45 y[i] = (use_alpha == false) ? acc : alpha*acc;
max@0 46 }
max@0 47 else
max@0 48 {
max@0 49 const eT tmp = y[i];
max@0 50
max@0 51 y[i] = beta*tmp + ( (use_alpha == false) ? acc : alpha*acc );
max@0 52 }
max@0 53 }
max@0 54
max@0 55
max@0 56
max@0 57 template<typename eT>
max@0 58 arma_hot
max@0 59 inline
max@0 60 static
max@0 61 void
max@0 62 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
max@0 63 {
max@0 64 arma_extra_debug_sigprint();
max@0 65
max@0 66 const eT* Am = A.memptr();
max@0 67
max@0 68 switch(A.n_rows)
max@0 69 {
max@0 70 case 1:
max@0 71 {
max@0 72 const eT acc = Am[0] * x[0];
max@0 73
max@0 74 assign<eT, 0>(y, acc, alpha, beta);
max@0 75 }
max@0 76 break;
max@0 77
max@0 78
max@0 79 case 2:
max@0 80 {
max@0 81 const eT x0 = x[0];
max@0 82 const eT x1 = x[1];
max@0 83
max@0 84 const eT acc0 = Am[pos<0,0>::n2]*x0 + Am[pos<0,1>::n2]*x1;
max@0 85 const eT acc1 = Am[pos<1,0>::n2]*x0 + Am[pos<1,1>::n2]*x1;
max@0 86
max@0 87 assign<eT, 0>(y, acc0, alpha, beta);
max@0 88 assign<eT, 1>(y, acc1, alpha, beta);
max@0 89 }
max@0 90 break;
max@0 91
max@0 92
max@0 93 case 3:
max@0 94 {
max@0 95 const eT x0 = x[0];
max@0 96 const eT x1 = x[1];
max@0 97 const eT x2 = x[2];
max@0 98
max@0 99 const eT acc0 = Am[pos<0,0>::n3]*x0 + Am[pos<0,1>::n3]*x1 + Am[pos<0,2>::n3]*x2;
max@0 100 const eT acc1 = Am[pos<1,0>::n3]*x0 + Am[pos<1,1>::n3]*x1 + Am[pos<1,2>::n3]*x2;
max@0 101 const eT acc2 = Am[pos<2,0>::n3]*x0 + Am[pos<2,1>::n3]*x1 + Am[pos<2,2>::n3]*x2;
max@0 102
max@0 103 assign<eT, 0>(y, acc0, alpha, beta);
max@0 104 assign<eT, 1>(y, acc1, alpha, beta);
max@0 105 assign<eT, 2>(y, acc2, alpha, beta);
max@0 106 }
max@0 107 break;
max@0 108
max@0 109
max@0 110 case 4:
max@0 111 {
max@0 112 const eT x0 = x[0];
max@0 113 const eT x1 = x[1];
max@0 114 const eT x2 = x[2];
max@0 115 const eT x3 = x[3];
max@0 116
max@0 117 const eT acc0 = Am[pos<0,0>::n4]*x0 + Am[pos<0,1>::n4]*x1 + Am[pos<0,2>::n4]*x2 + Am[pos<0,3>::n4]*x3;
max@0 118 const eT acc1 = Am[pos<1,0>::n4]*x0 + Am[pos<1,1>::n4]*x1 + Am[pos<1,2>::n4]*x2 + Am[pos<1,3>::n4]*x3;
max@0 119 const eT acc2 = Am[pos<2,0>::n4]*x0 + Am[pos<2,1>::n4]*x1 + Am[pos<2,2>::n4]*x2 + Am[pos<2,3>::n4]*x3;
max@0 120 const eT acc3 = Am[pos<3,0>::n4]*x0 + Am[pos<3,1>::n4]*x1 + Am[pos<3,2>::n4]*x2 + Am[pos<3,3>::n4]*x3;
max@0 121
max@0 122 assign<eT, 0>(y, acc0, alpha, beta);
max@0 123 assign<eT, 1>(y, acc1, alpha, beta);
max@0 124 assign<eT, 2>(y, acc2, alpha, beta);
max@0 125 assign<eT, 3>(y, acc3, alpha, beta);
max@0 126 }
max@0 127 break;
max@0 128
max@0 129
max@0 130 default:
max@0 131 ;
max@0 132 }
max@0 133 }
max@0 134
max@0 135 };
max@0 136
max@0 137
max@0 138
max@0 139 //! \brief
max@0 140 //! Partial emulation of ATLAS/BLAS gemv().
max@0 141 //! 'y' is assumed to have been set to the correct size (i.e. taking into account the transpose)
max@0 142
max@0 143 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
max@0 144 class gemv_emul_large
max@0 145 {
max@0 146 public:
max@0 147
max@0 148 template<typename eT>
max@0 149 arma_hot
max@0 150 inline
max@0 151 static
max@0 152 void
max@0 153 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
max@0 154 {
max@0 155 arma_extra_debug_sigprint();
max@0 156
max@0 157 const uword A_n_rows = A.n_rows;
max@0 158 const uword A_n_cols = A.n_cols;
max@0 159
max@0 160 if(do_trans_A == false)
max@0 161 {
max@0 162 if(A_n_rows == 1)
max@0 163 {
max@0 164 const eT acc = op_dot::direct_dot_arma(A_n_cols, A.mem, x);
max@0 165
max@0 166 if( (use_alpha == false) && (use_beta == false) )
max@0 167 {
max@0 168 y[0] = acc;
max@0 169 }
max@0 170 else
max@0 171 if( (use_alpha == true) && (use_beta == false) )
max@0 172 {
max@0 173 y[0] = alpha * acc;
max@0 174 }
max@0 175 else
max@0 176 if( (use_alpha == false) && (use_beta == true) )
max@0 177 {
max@0 178 y[0] = acc + beta*y[0];
max@0 179 }
max@0 180 else
max@0 181 if( (use_alpha == true) && (use_beta == true) )
max@0 182 {
max@0 183 y[0] = alpha*acc + beta*y[0];
max@0 184 }
max@0 185 }
max@0 186 else
max@0 187 for(uword row=0; row < A_n_rows; ++row)
max@0 188 {
max@0 189 eT acc = eT(0);
max@0 190
max@0 191 for(uword i=0; i < A_n_cols; ++i)
max@0 192 {
max@0 193 acc += A.at(row,i) * x[i];
max@0 194 }
max@0 195
max@0 196 if( (use_alpha == false) && (use_beta == false) )
max@0 197 {
max@0 198 y[row] = acc;
max@0 199 }
max@0 200 else
max@0 201 if( (use_alpha == true) && (use_beta == false) )
max@0 202 {
max@0 203 y[row] = alpha * acc;
max@0 204 }
max@0 205 else
max@0 206 if( (use_alpha == false) && (use_beta == true) )
max@0 207 {
max@0 208 y[row] = acc + beta*y[row];
max@0 209 }
max@0 210 else
max@0 211 if( (use_alpha == true) && (use_beta == true) )
max@0 212 {
max@0 213 y[row] = alpha*acc + beta*y[row];
max@0 214 }
max@0 215 }
max@0 216 }
max@0 217 else
max@0 218 if(do_trans_A == true)
max@0 219 {
max@0 220 for(uword col=0; col < A_n_cols; ++col)
max@0 221 {
max@0 222 // col is interpreted as row when storing the results in 'y'
max@0 223
max@0 224
max@0 225 // const eT* A_coldata = A.colptr(col);
max@0 226 //
max@0 227 // eT acc = eT(0);
max@0 228 // for(uword row=0; row < A_n_rows; ++row)
max@0 229 // {
max@0 230 // acc += A_coldata[row] * x[row];
max@0 231 // }
max@0 232
max@0 233 const eT acc = op_dot::direct_dot_arma(A_n_rows, A.colptr(col), x);
max@0 234
max@0 235 if( (use_alpha == false) && (use_beta == false) )
max@0 236 {
max@0 237 y[col] = acc;
max@0 238 }
max@0 239 else
max@0 240 if( (use_alpha == true) && (use_beta == false) )
max@0 241 {
max@0 242 y[col] = alpha * acc;
max@0 243 }
max@0 244 else
max@0 245 if( (use_alpha == false) && (use_beta == true) )
max@0 246 {
max@0 247 y[col] = acc + beta*y[col];
max@0 248 }
max@0 249 else
max@0 250 if( (use_alpha == true) && (use_beta == true) )
max@0 251 {
max@0 252 y[col] = alpha*acc + beta*y[col];
max@0 253 }
max@0 254
max@0 255 }
max@0 256 }
max@0 257 }
max@0 258
max@0 259 };
max@0 260
max@0 261
max@0 262
max@0 263 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
max@0 264 class gemv_emul
max@0 265 {
max@0 266 public:
max@0 267
max@0 268 template<typename eT>
max@0 269 arma_hot
max@0 270 inline
max@0 271 static
max@0 272 void
max@0 273 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0), const typename arma_not_cx<eT>::result* junk = 0 )
max@0 274 {
max@0 275 arma_extra_debug_sigprint();
max@0 276 arma_ignore(junk);
max@0 277
max@0 278 const uword A_n_rows = A.n_rows;
max@0 279 const uword A_n_cols = A.n_cols;
max@0 280
max@0 281 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) )
max@0 282 {
max@0 283 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(y, A, x, alpha, beta);
max@0 284 }
max@0 285 else
max@0 286 {
max@0 287 gemv_emul_large<do_trans_A, use_alpha, use_beta>::apply(y, A, x, alpha, beta);
max@0 288 }
max@0 289 }
max@0 290
max@0 291
max@0 292
max@0 293 template<typename eT>
max@0 294 arma_hot
max@0 295 inline
max@0 296 static
max@0 297 void
max@0 298 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0), const typename arma_cx_only<eT>::result* junk = 0 )
max@0 299 {
max@0 300 arma_extra_debug_sigprint();
max@0 301
max@0 302 Mat<eT> tmp_A;
max@0 303
max@0 304 if(do_trans_A)
max@0 305 {
max@0 306 op_htrans::apply_noalias(tmp_A, A);
max@0 307 }
max@0 308
max@0 309 const Mat<eT>& AA = (do_trans_A == false) ? A : tmp_A;
max@0 310
max@0 311 const uword AA_n_rows = AA.n_rows;
max@0 312 const uword AA_n_cols = AA.n_cols;
max@0 313
max@0 314 if( (AA_n_rows <= 4) && (AA_n_rows == AA_n_cols) )
max@0 315 {
max@0 316 gemv_emul_tinysq<false, use_alpha, use_beta>::apply(y, AA, x, alpha, beta);
max@0 317 }
max@0 318 else
max@0 319 {
max@0 320 gemv_emul_large<false, use_alpha, use_beta>::apply(y, AA, x, alpha, beta);
max@0 321 }
max@0 322 }
max@0 323 };
max@0 324
max@0 325
max@0 326
max@0 327 //! \brief
max@0 328 //! Wrapper for ATLAS/BLAS gemv function, using template arguments to control the arguments passed to gemv.
max@0 329 //! 'y' is assumed to have been set to the correct size (i.e. taking into account the transpose)
max@0 330
max@0 331 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
max@0 332 class gemv
max@0 333 {
max@0 334 public:
max@0 335
max@0 336 template<typename eT>
max@0 337 inline
max@0 338 static
max@0 339 void
max@0 340 apply_blas_type( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
max@0 341 {
max@0 342 arma_extra_debug_sigprint();
max@0 343
max@0 344 if(A.n_elem <= 64u)
max@0 345 {
max@0 346 gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
max@0 347 }
max@0 348 else
max@0 349 {
max@0 350 #if defined(ARMA_USE_ATLAS)
max@0 351 {
max@0 352 arma_extra_debug_print("atlas::cblas_gemv()");
max@0 353
max@0 354 atlas::cblas_gemv<eT>
max@0 355 (
max@0 356 atlas::CblasColMajor,
max@0 357 (do_trans_A) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
max@0 358 A.n_rows,
max@0 359 A.n_cols,
max@0 360 (use_alpha) ? alpha : eT(1),
max@0 361 A.mem,
max@0 362 A.n_rows,
max@0 363 x,
max@0 364 1,
max@0 365 (use_beta) ? beta : eT(0),
max@0 366 y,
max@0 367 1
max@0 368 );
max@0 369 }
max@0 370 #elif defined(ARMA_USE_BLAS)
max@0 371 {
max@0 372 arma_extra_debug_print("blas::gemv()");
max@0 373
max@0 374 const char trans_A = (do_trans_A) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N';
max@0 375 const blas_int m = A.n_rows;
max@0 376 const blas_int n = A.n_cols;
max@0 377 const eT local_alpha = (use_alpha) ? alpha : eT(1);
max@0 378 //const blas_int lda = A.n_rows;
max@0 379 const blas_int inc = 1;
max@0 380 const eT local_beta = (use_beta) ? beta : eT(0);
max@0 381
max@0 382 arma_extra_debug_print( arma_boost::format("blas::gemv(): trans_A = %c") % trans_A );
max@0 383
max@0 384 blas::gemv<eT>
max@0 385 (
max@0 386 &trans_A,
max@0 387 &m,
max@0 388 &n,
max@0 389 &local_alpha,
max@0 390 A.mem,
max@0 391 &m, // lda
max@0 392 x,
max@0 393 &inc,
max@0 394 &local_beta,
max@0 395 y,
max@0 396 &inc
max@0 397 );
max@0 398 }
max@0 399 #else
max@0 400 {
max@0 401 gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
max@0 402 }
max@0 403 #endif
max@0 404 }
max@0 405
max@0 406 }
max@0 407
max@0 408
max@0 409
max@0 410 template<typename eT>
max@0 411 arma_inline
max@0 412 static
max@0 413 void
max@0 414 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
max@0 415 {
max@0 416 gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
max@0 417 }
max@0 418
max@0 419
max@0 420
max@0 421 arma_inline
max@0 422 static
max@0 423 void
max@0 424 apply
max@0 425 (
max@0 426 float* y,
max@0 427 const Mat<float>& A,
max@0 428 const float* x,
max@0 429 const float alpha = float(1),
max@0 430 const float beta = float(0)
max@0 431 )
max@0 432 {
max@0 433 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
max@0 434 }
max@0 435
max@0 436
max@0 437
max@0 438 arma_inline
max@0 439 static
max@0 440 void
max@0 441 apply
max@0 442 (
max@0 443 double* y,
max@0 444 const Mat<double>& A,
max@0 445 const double* x,
max@0 446 const double alpha = double(1),
max@0 447 const double beta = double(0)
max@0 448 )
max@0 449 {
max@0 450 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
max@0 451 }
max@0 452
max@0 453
max@0 454
max@0 455 arma_inline
max@0 456 static
max@0 457 void
max@0 458 apply
max@0 459 (
max@0 460 std::complex<float>* y,
max@0 461 const Mat< std::complex<float > >& A,
max@0 462 const std::complex<float>* x,
max@0 463 const std::complex<float> alpha = std::complex<float>(1),
max@0 464 const std::complex<float> beta = std::complex<float>(0)
max@0 465 )
max@0 466 {
max@0 467 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
max@0 468 }
max@0 469
max@0 470
max@0 471
max@0 472 arma_inline
max@0 473 static
max@0 474 void
max@0 475 apply
max@0 476 (
max@0 477 std::complex<double>* y,
max@0 478 const Mat< std::complex<double> >& A,
max@0 479 const std::complex<double>* x,
max@0 480 const std::complex<double> alpha = std::complex<double>(1),
max@0 481 const std::complex<double> beta = std::complex<double>(0)
max@0 482 )
max@0 483 {
max@0 484 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
max@0 485 }
max@0 486
max@0 487
max@0 488
max@0 489 };
max@0 490
max@0 491
max@0 492 //! @}