comparison armadillo-3.900.4/include/armadillo_bits/gemv.hpp @ 49:1ec0e2823891

Switch to using subrepo copies of qm-dsp, nnls-chroma, vamp-plugin-sdk; update Armadillo version; assume build without external BLAS/LAPACK
author Chris Cannam
date Thu, 13 Jun 2013 10:25:24 +0100
parents
children
comparison
equal deleted inserted replaced
48:69251e11a913 49:1ec0e2823891
1 // Copyright (C) 2008-2013 NICTA (www.nicta.com.au)
2 // Copyright (C) 2008-2013 Conrad Sanderson
3 //
4 // This Source Code Form is subject to the terms of the Mozilla Public
5 // License, v. 2.0. If a copy of the MPL was not distributed with this
6 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
7
8
9 //! \addtogroup gemv
10 //! @{
11
12
13
14 //! for tiny square matrices, size <= 4x4
15 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
16 class gemv_emul_tinysq
17 {
18 public:
19
20
21 template<const uword row, const uword col>
22 struct pos
23 {
24 static const uword n2 = (do_trans_A == false) ? (row + col*2) : (col + row*2);
25 static const uword n3 = (do_trans_A == false) ? (row + col*3) : (col + row*3);
26 static const uword n4 = (do_trans_A == false) ? (row + col*4) : (col + row*4);
27 };
28
29
30
31 template<typename eT, const uword i>
32 arma_hot
33 arma_inline
34 static
35 void
36 assign(eT* y, const eT acc, const eT alpha, const eT beta)
37 {
38 if(use_beta == false)
39 {
40 y[i] = (use_alpha == false) ? acc : alpha*acc;
41 }
42 else
43 {
44 const eT tmp = y[i];
45
46 y[i] = beta*tmp + ( (use_alpha == false) ? acc : alpha*acc );
47 }
48 }
49
50
51
52 template<typename eT, typename TA>
53 arma_hot
54 inline
55 static
56 void
57 apply( eT* y, const TA& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
58 {
59 arma_extra_debug_sigprint();
60
61 const eT* Am = A.memptr();
62
63 switch(A.n_rows)
64 {
65 case 1:
66 {
67 const eT acc = Am[0] * x[0];
68
69 assign<eT, 0>(y, acc, alpha, beta);
70 }
71 break;
72
73
74 case 2:
75 {
76 const eT x0 = x[0];
77 const eT x1 = x[1];
78
79 const eT acc0 = Am[pos<0,0>::n2]*x0 + Am[pos<0,1>::n2]*x1;
80 const eT acc1 = Am[pos<1,0>::n2]*x0 + Am[pos<1,1>::n2]*x1;
81
82 assign<eT, 0>(y, acc0, alpha, beta);
83 assign<eT, 1>(y, acc1, alpha, beta);
84 }
85 break;
86
87
88 case 3:
89 {
90 const eT x0 = x[0];
91 const eT x1 = x[1];
92 const eT x2 = x[2];
93
94 const eT acc0 = Am[pos<0,0>::n3]*x0 + Am[pos<0,1>::n3]*x1 + Am[pos<0,2>::n3]*x2;
95 const eT acc1 = Am[pos<1,0>::n3]*x0 + Am[pos<1,1>::n3]*x1 + Am[pos<1,2>::n3]*x2;
96 const eT acc2 = Am[pos<2,0>::n3]*x0 + Am[pos<2,1>::n3]*x1 + Am[pos<2,2>::n3]*x2;
97
98 assign<eT, 0>(y, acc0, alpha, beta);
99 assign<eT, 1>(y, acc1, alpha, beta);
100 assign<eT, 2>(y, acc2, alpha, beta);
101 }
102 break;
103
104
105 case 4:
106 {
107 const eT x0 = x[0];
108 const eT x1 = x[1];
109 const eT x2 = x[2];
110 const eT x3 = x[3];
111
112 const eT acc0 = Am[pos<0,0>::n4]*x0 + Am[pos<0,1>::n4]*x1 + Am[pos<0,2>::n4]*x2 + Am[pos<0,3>::n4]*x3;
113 const eT acc1 = Am[pos<1,0>::n4]*x0 + Am[pos<1,1>::n4]*x1 + Am[pos<1,2>::n4]*x2 + Am[pos<1,3>::n4]*x3;
114 const eT acc2 = Am[pos<2,0>::n4]*x0 + Am[pos<2,1>::n4]*x1 + Am[pos<2,2>::n4]*x2 + Am[pos<2,3>::n4]*x3;
115 const eT acc3 = Am[pos<3,0>::n4]*x0 + Am[pos<3,1>::n4]*x1 + Am[pos<3,2>::n4]*x2 + Am[pos<3,3>::n4]*x3;
116
117 assign<eT, 0>(y, acc0, alpha, beta);
118 assign<eT, 1>(y, acc1, alpha, beta);
119 assign<eT, 2>(y, acc2, alpha, beta);
120 assign<eT, 3>(y, acc3, alpha, beta);
121 }
122 break;
123
124
125 default:
126 ;
127 }
128 }
129
130 };
131
132
133
134 class gemv_emul_large_helper
135 {
136 public:
137
138 template<typename eT, typename TA>
139 arma_hot
140 inline
141 static
142 typename arma_not_cx<eT>::result
143 dot_row_col( const TA& A, const eT* x, const uword row, const uword N )
144 {
145 eT acc1 = eT(0);
146 eT acc2 = eT(0);
147
148 uword i,j;
149 for(i=0, j=1; j < N; i+=2, j+=2)
150 {
151 const eT xi = x[i];
152 const eT xj = x[j];
153
154 acc1 += A.at(row,i) * xi;
155 acc2 += A.at(row,j) * xj;
156 }
157
158 if(i < N)
159 {
160 acc1 += A.at(row,i) * x[i];
161 }
162
163 return (acc1 + acc2);
164 }
165
166
167
168 template<typename eT, typename TA>
169 arma_hot
170 inline
171 static
172 typename arma_cx_only<eT>::result
173 dot_row_col( const TA& A, const eT* x, const uword row, const uword N )
174 {
175 typedef typename get_pod_type<eT>::result T;
176
177 T val_real = T(0);
178 T val_imag = T(0);
179
180 for(uword i=0; i<N; ++i)
181 {
182 const std::complex<T>& Ai = A.at(row,i);
183 const std::complex<T>& xi = x[i];
184
185 const T a = Ai.real();
186 const T b = Ai.imag();
187
188 const T c = xi.real();
189 const T d = xi.imag();
190
191 val_real += (a*c) - (b*d);
192 val_imag += (a*d) + (b*c);
193 }
194
195 return std::complex<T>(val_real, val_imag);
196 }
197
198 };
199
200
201
202 //! \brief
203 //! Partial emulation of ATLAS/BLAS gemv().
204 //! 'y' is assumed to have been set to the correct size (i.e. taking into account the transpose)
205
206 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
207 class gemv_emul_large
208 {
209 public:
210
211 template<typename eT, typename TA>
212 arma_hot
213 inline
214 static
215 void
216 apply( eT* y, const TA& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
217 {
218 arma_extra_debug_sigprint();
219
220 const uword A_n_rows = A.n_rows;
221 const uword A_n_cols = A.n_cols;
222
223 if(do_trans_A == false)
224 {
225 if(A_n_rows == 1)
226 {
227 const eT acc = op_dot::direct_dot_arma(A_n_cols, A.memptr(), x);
228
229 if( (use_alpha == false) && (use_beta == false) )
230 {
231 y[0] = acc;
232 }
233 else
234 if( (use_alpha == true) && (use_beta == false) )
235 {
236 y[0] = alpha * acc;
237 }
238 else
239 if( (use_alpha == false) && (use_beta == true) )
240 {
241 y[0] = acc + beta*y[0];
242 }
243 else
244 if( (use_alpha == true) && (use_beta == true) )
245 {
246 y[0] = alpha*acc + beta*y[0];
247 }
248 }
249 else
250 for(uword row=0; row < A_n_rows; ++row)
251 {
252 const eT acc = gemv_emul_large_helper::dot_row_col(A, x, row, A_n_cols);
253
254 if( (use_alpha == false) && (use_beta == false) )
255 {
256 y[row] = acc;
257 }
258 else
259 if( (use_alpha == true) && (use_beta == false) )
260 {
261 y[row] = alpha * acc;
262 }
263 else
264 if( (use_alpha == false) && (use_beta == true) )
265 {
266 y[row] = acc + beta*y[row];
267 }
268 else
269 if( (use_alpha == true) && (use_beta == true) )
270 {
271 y[row] = alpha*acc + beta*y[row];
272 }
273 }
274 }
275 else
276 if(do_trans_A == true)
277 {
278 for(uword col=0; col < A_n_cols; ++col)
279 {
280 // col is interpreted as row when storing the results in 'y'
281
282
283 // const eT* A_coldata = A.colptr(col);
284 //
285 // eT acc = eT(0);
286 // for(uword row=0; row < A_n_rows; ++row)
287 // {
288 // acc += A_coldata[row] * x[row];
289 // }
290
291 const eT acc = op_dot::direct_dot_arma(A_n_rows, A.colptr(col), x);
292
293 if( (use_alpha == false) && (use_beta == false) )
294 {
295 y[col] = acc;
296 }
297 else
298 if( (use_alpha == true) && (use_beta == false) )
299 {
300 y[col] = alpha * acc;
301 }
302 else
303 if( (use_alpha == false) && (use_beta == true) )
304 {
305 y[col] = acc + beta*y[col];
306 }
307 else
308 if( (use_alpha == true) && (use_beta == true) )
309 {
310 y[col] = alpha*acc + beta*y[col];
311 }
312
313 }
314 }
315 }
316
317 };
318
319
320
321 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
322 class gemv_emul
323 {
324 public:
325
326 template<typename eT, typename TA>
327 arma_hot
328 inline
329 static
330 void
331 apply( eT* y, const TA& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0), const typename arma_not_cx<eT>::result* junk = 0 )
332 {
333 arma_extra_debug_sigprint();
334 arma_ignore(junk);
335
336 const uword A_n_rows = A.n_rows;
337 const uword A_n_cols = A.n_cols;
338
339 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) )
340 {
341 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(y, A, x, alpha, beta);
342 }
343 else
344 {
345 gemv_emul_large<do_trans_A, use_alpha, use_beta>::apply(y, A, x, alpha, beta);
346 }
347 }
348
349
350
351 template<typename eT>
352 arma_hot
353 inline
354 static
355 void
356 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0), const typename arma_cx_only<eT>::result* junk = 0 )
357 {
358 arma_extra_debug_sigprint();
359 arma_ignore(junk);
360
361 Mat<eT> tmp_A;
362
363 if(do_trans_A)
364 {
365 op_htrans::apply_noalias(tmp_A, A);
366 }
367
368 const Mat<eT>& AA = (do_trans_A == false) ? A : tmp_A;
369
370 const uword AA_n_rows = AA.n_rows;
371 const uword AA_n_cols = AA.n_cols;
372
373 if( (AA_n_rows <= 4) && (AA_n_rows == AA_n_cols) )
374 {
375 gemv_emul_tinysq<false, use_alpha, use_beta>::apply(y, AA, x, alpha, beta);
376 }
377 else
378 {
379 gemv_emul_large<false, use_alpha, use_beta>::apply(y, AA, x, alpha, beta);
380 }
381 }
382 };
383
384
385
386 //! \brief
387 //! Wrapper for ATLAS/BLAS gemv function, using template arguments to control the arguments passed to gemv.
388 //! 'y' is assumed to have been set to the correct size (i.e. taking into account the transpose)
389
390 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
391 class gemv
392 {
393 public:
394
395 template<typename eT, typename TA>
396 inline
397 static
398 void
399 apply_blas_type( eT* y, const TA& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
400 {
401 arma_extra_debug_sigprint();
402
403 //const uword threshold = (is_complex<eT>::value == true) ? 16u : 64u;
404 const uword threshold = (is_complex<eT>::value == true) ? 64u : 100u;
405
406 if(A.n_elem <= threshold)
407 {
408 gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
409 }
410 else
411 {
412 #if defined(ARMA_USE_ATLAS)
413 {
414 if(is_complex<eT>::value == false)
415 {
416 // use gemm() instead of gemv() to work around a speed issue in Atlas 3.8.4
417
418 arma_extra_debug_print("atlas::cblas_gemm()");
419
420 atlas::cblas_gemm<eT>
421 (
422 atlas::CblasColMajor,
423 (do_trans_A) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
424 atlas::CblasNoTrans,
425 (do_trans_A) ? A.n_cols : A.n_rows,
426 1,
427 (do_trans_A) ? A.n_rows : A.n_cols,
428 (use_alpha) ? alpha : eT(1),
429 A.mem,
430 A.n_rows,
431 x,
432 (do_trans_A) ? A.n_rows : A.n_cols,
433 (use_beta) ? beta : eT(0),
434 y,
435 (do_trans_A) ? A.n_cols : A.n_rows
436 );
437 }
438 else
439 {
440 arma_extra_debug_print("atlas::cblas_gemv()");
441
442 atlas::cblas_gemv<eT>
443 (
444 atlas::CblasColMajor,
445 (do_trans_A) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
446 A.n_rows,
447 A.n_cols,
448 (use_alpha) ? alpha : eT(1),
449 A.mem,
450 A.n_rows,
451 x,
452 1,
453 (use_beta) ? beta : eT(0),
454 y,
455 1
456 );
457 }
458 }
459 #elif defined(ARMA_USE_BLAS)
460 {
461 arma_extra_debug_print("blas::gemv()");
462
463 const char trans_A = (do_trans_A) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N';
464 const blas_int m = A.n_rows;
465 const blas_int n = A.n_cols;
466 const eT local_alpha = (use_alpha) ? alpha : eT(1);
467 //const blas_int lda = A.n_rows;
468 const blas_int inc = 1;
469 const eT local_beta = (use_beta) ? beta : eT(0);
470
471 arma_extra_debug_print( arma_boost::format("blas::gemv(): trans_A = %c") % trans_A );
472
473 blas::gemv<eT>
474 (
475 &trans_A,
476 &m,
477 &n,
478 &local_alpha,
479 A.mem,
480 &m, // lda
481 x,
482 &inc,
483 &local_beta,
484 y,
485 &inc
486 );
487 }
488 #else
489 {
490 gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
491 }
492 #endif
493 }
494
495 }
496
497
498
499 template<typename eT, typename TA>
500 arma_inline
501 static
502 void
503 apply( eT* y, const TA& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
504 {
505 gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
506 }
507
508
509
510 template<typename TA>
511 arma_inline
512 static
513 void
514 apply
515 (
516 float* y,
517 const TA& A,
518 const float* x,
519 const float alpha = float(1),
520 const float beta = float(0)
521 )
522 {
523 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
524 }
525
526
527
528 template<typename TA>
529 arma_inline
530 static
531 void
532 apply
533 (
534 double* y,
535 const TA& A,
536 const double* x,
537 const double alpha = double(1),
538 const double beta = double(0)
539 )
540 {
541 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
542 }
543
544
545
546 template<typename TA>
547 arma_inline
548 static
549 void
550 apply
551 (
552 std::complex<float>* y,
553 const TA& A,
554 const std::complex<float>* x,
555 const std::complex<float> alpha = std::complex<float>(1),
556 const std::complex<float> beta = std::complex<float>(0)
557 )
558 {
559 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
560 }
561
562
563
564 template<typename TA>
565 arma_inline
566 static
567 void
568 apply
569 (
570 std::complex<double>* y,
571 const TA& A,
572 const std::complex<double>* x,
573 const std::complex<double> alpha = std::complex<double>(1),
574 const std::complex<double> beta = std::complex<double>(0)
575 )
576 {
577 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
578 }
579
580
581
582 };
583
584
585 //! @}