Mercurial > hg > segmenter-vamp-plugin
comparison armadillo-3.900.4/include/armadillo_bits/gemv.hpp @ 49:1ec0e2823891
Switch to using subrepo copies of qm-dsp, nnls-chroma, vamp-plugin-sdk; update Armadillo version; assume build without external BLAS/LAPACK
author | Chris Cannam |
---|---|
date | Thu, 13 Jun 2013 10:25:24 +0100 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
48:69251e11a913 | 49:1ec0e2823891 |
---|---|
1 // Copyright (C) 2008-2013 NICTA (www.nicta.com.au) | |
2 // Copyright (C) 2008-2013 Conrad Sanderson | |
3 // | |
4 // This Source Code Form is subject to the terms of the Mozilla Public | |
5 // License, v. 2.0. If a copy of the MPL was not distributed with this | |
6 // file, You can obtain one at http://mozilla.org/MPL/2.0/. | |
7 | |
8 | |
9 //! \addtogroup gemv | |
10 //! @{ | |
11 | |
12 | |
13 | |
14 //! for tiny square matrices, size <= 4x4 | |
15 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false> | |
16 class gemv_emul_tinysq | |
17 { | |
18 public: | |
19 | |
20 | |
21 template<const uword row, const uword col> | |
22 struct pos | |
23 { | |
24 static const uword n2 = (do_trans_A == false) ? (row + col*2) : (col + row*2); | |
25 static const uword n3 = (do_trans_A == false) ? (row + col*3) : (col + row*3); | |
26 static const uword n4 = (do_trans_A == false) ? (row + col*4) : (col + row*4); | |
27 }; | |
28 | |
29 | |
30 | |
31 template<typename eT, const uword i> | |
32 arma_hot | |
33 arma_inline | |
34 static | |
35 void | |
36 assign(eT* y, const eT acc, const eT alpha, const eT beta) | |
37 { | |
38 if(use_beta == false) | |
39 { | |
40 y[i] = (use_alpha == false) ? acc : alpha*acc; | |
41 } | |
42 else | |
43 { | |
44 const eT tmp = y[i]; | |
45 | |
46 y[i] = beta*tmp + ( (use_alpha == false) ? acc : alpha*acc ); | |
47 } | |
48 } | |
49 | |
50 | |
51 | |
52 template<typename eT, typename TA> | |
53 arma_hot | |
54 inline | |
55 static | |
56 void | |
57 apply( eT* y, const TA& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) ) | |
58 { | |
59 arma_extra_debug_sigprint(); | |
60 | |
61 const eT* Am = A.memptr(); | |
62 | |
63 switch(A.n_rows) | |
64 { | |
65 case 1: | |
66 { | |
67 const eT acc = Am[0] * x[0]; | |
68 | |
69 assign<eT, 0>(y, acc, alpha, beta); | |
70 } | |
71 break; | |
72 | |
73 | |
74 case 2: | |
75 { | |
76 const eT x0 = x[0]; | |
77 const eT x1 = x[1]; | |
78 | |
79 const eT acc0 = Am[pos<0,0>::n2]*x0 + Am[pos<0,1>::n2]*x1; | |
80 const eT acc1 = Am[pos<1,0>::n2]*x0 + Am[pos<1,1>::n2]*x1; | |
81 | |
82 assign<eT, 0>(y, acc0, alpha, beta); | |
83 assign<eT, 1>(y, acc1, alpha, beta); | |
84 } | |
85 break; | |
86 | |
87 | |
88 case 3: | |
89 { | |
90 const eT x0 = x[0]; | |
91 const eT x1 = x[1]; | |
92 const eT x2 = x[2]; | |
93 | |
94 const eT acc0 = Am[pos<0,0>::n3]*x0 + Am[pos<0,1>::n3]*x1 + Am[pos<0,2>::n3]*x2; | |
95 const eT acc1 = Am[pos<1,0>::n3]*x0 + Am[pos<1,1>::n3]*x1 + Am[pos<1,2>::n3]*x2; | |
96 const eT acc2 = Am[pos<2,0>::n3]*x0 + Am[pos<2,1>::n3]*x1 + Am[pos<2,2>::n3]*x2; | |
97 | |
98 assign<eT, 0>(y, acc0, alpha, beta); | |
99 assign<eT, 1>(y, acc1, alpha, beta); | |
100 assign<eT, 2>(y, acc2, alpha, beta); | |
101 } | |
102 break; | |
103 | |
104 | |
105 case 4: | |
106 { | |
107 const eT x0 = x[0]; | |
108 const eT x1 = x[1]; | |
109 const eT x2 = x[2]; | |
110 const eT x3 = x[3]; | |
111 | |
112 const eT acc0 = Am[pos<0,0>::n4]*x0 + Am[pos<0,1>::n4]*x1 + Am[pos<0,2>::n4]*x2 + Am[pos<0,3>::n4]*x3; | |
113 const eT acc1 = Am[pos<1,0>::n4]*x0 + Am[pos<1,1>::n4]*x1 + Am[pos<1,2>::n4]*x2 + Am[pos<1,3>::n4]*x3; | |
114 const eT acc2 = Am[pos<2,0>::n4]*x0 + Am[pos<2,1>::n4]*x1 + Am[pos<2,2>::n4]*x2 + Am[pos<2,3>::n4]*x3; | |
115 const eT acc3 = Am[pos<3,0>::n4]*x0 + Am[pos<3,1>::n4]*x1 + Am[pos<3,2>::n4]*x2 + Am[pos<3,3>::n4]*x3; | |
116 | |
117 assign<eT, 0>(y, acc0, alpha, beta); | |
118 assign<eT, 1>(y, acc1, alpha, beta); | |
119 assign<eT, 2>(y, acc2, alpha, beta); | |
120 assign<eT, 3>(y, acc3, alpha, beta); | |
121 } | |
122 break; | |
123 | |
124 | |
125 default: | |
126 ; | |
127 } | |
128 } | |
129 | |
130 }; | |
131 | |
132 | |
133 | |
134 class gemv_emul_large_helper | |
135 { | |
136 public: | |
137 | |
138 template<typename eT, typename TA> | |
139 arma_hot | |
140 inline | |
141 static | |
142 typename arma_not_cx<eT>::result | |
143 dot_row_col( const TA& A, const eT* x, const uword row, const uword N ) | |
144 { | |
145 eT acc1 = eT(0); | |
146 eT acc2 = eT(0); | |
147 | |
148 uword i,j; | |
149 for(i=0, j=1; j < N; i+=2, j+=2) | |
150 { | |
151 const eT xi = x[i]; | |
152 const eT xj = x[j]; | |
153 | |
154 acc1 += A.at(row,i) * xi; | |
155 acc2 += A.at(row,j) * xj; | |
156 } | |
157 | |
158 if(i < N) | |
159 { | |
160 acc1 += A.at(row,i) * x[i]; | |
161 } | |
162 | |
163 return (acc1 + acc2); | |
164 } | |
165 | |
166 | |
167 | |
168 template<typename eT, typename TA> | |
169 arma_hot | |
170 inline | |
171 static | |
172 typename arma_cx_only<eT>::result | |
173 dot_row_col( const TA& A, const eT* x, const uword row, const uword N ) | |
174 { | |
175 typedef typename get_pod_type<eT>::result T; | |
176 | |
177 T val_real = T(0); | |
178 T val_imag = T(0); | |
179 | |
180 for(uword i=0; i<N; ++i) | |
181 { | |
182 const std::complex<T>& Ai = A.at(row,i); | |
183 const std::complex<T>& xi = x[i]; | |
184 | |
185 const T a = Ai.real(); | |
186 const T b = Ai.imag(); | |
187 | |
188 const T c = xi.real(); | |
189 const T d = xi.imag(); | |
190 | |
191 val_real += (a*c) - (b*d); | |
192 val_imag += (a*d) + (b*c); | |
193 } | |
194 | |
195 return std::complex<T>(val_real, val_imag); | |
196 } | |
197 | |
198 }; | |
199 | |
200 | |
201 | |
202 //! \brief | |
203 //! Partial emulation of ATLAS/BLAS gemv(). | |
204 //! 'y' is assumed to have been set to the correct size (i.e. taking into account the transpose) | |
205 | |
206 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false> | |
207 class gemv_emul_large | |
208 { | |
209 public: | |
210 | |
211 template<typename eT, typename TA> | |
212 arma_hot | |
213 inline | |
214 static | |
215 void | |
216 apply( eT* y, const TA& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) ) | |
217 { | |
218 arma_extra_debug_sigprint(); | |
219 | |
220 const uword A_n_rows = A.n_rows; | |
221 const uword A_n_cols = A.n_cols; | |
222 | |
223 if(do_trans_A == false) | |
224 { | |
225 if(A_n_rows == 1) | |
226 { | |
227 const eT acc = op_dot::direct_dot_arma(A_n_cols, A.memptr(), x); | |
228 | |
229 if( (use_alpha == false) && (use_beta == false) ) | |
230 { | |
231 y[0] = acc; | |
232 } | |
233 else | |
234 if( (use_alpha == true) && (use_beta == false) ) | |
235 { | |
236 y[0] = alpha * acc; | |
237 } | |
238 else | |
239 if( (use_alpha == false) && (use_beta == true) ) | |
240 { | |
241 y[0] = acc + beta*y[0]; | |
242 } | |
243 else | |
244 if( (use_alpha == true) && (use_beta == true) ) | |
245 { | |
246 y[0] = alpha*acc + beta*y[0]; | |
247 } | |
248 } | |
249 else | |
250 for(uword row=0; row < A_n_rows; ++row) | |
251 { | |
252 const eT acc = gemv_emul_large_helper::dot_row_col(A, x, row, A_n_cols); | |
253 | |
254 if( (use_alpha == false) && (use_beta == false) ) | |
255 { | |
256 y[row] = acc; | |
257 } | |
258 else | |
259 if( (use_alpha == true) && (use_beta == false) ) | |
260 { | |
261 y[row] = alpha * acc; | |
262 } | |
263 else | |
264 if( (use_alpha == false) && (use_beta == true) ) | |
265 { | |
266 y[row] = acc + beta*y[row]; | |
267 } | |
268 else | |
269 if( (use_alpha == true) && (use_beta == true) ) | |
270 { | |
271 y[row] = alpha*acc + beta*y[row]; | |
272 } | |
273 } | |
274 } | |
275 else | |
276 if(do_trans_A == true) | |
277 { | |
278 for(uword col=0; col < A_n_cols; ++col) | |
279 { | |
280 // col is interpreted as row when storing the results in 'y' | |
281 | |
282 | |
283 // const eT* A_coldata = A.colptr(col); | |
284 // | |
285 // eT acc = eT(0); | |
286 // for(uword row=0; row < A_n_rows; ++row) | |
287 // { | |
288 // acc += A_coldata[row] * x[row]; | |
289 // } | |
290 | |
291 const eT acc = op_dot::direct_dot_arma(A_n_rows, A.colptr(col), x); | |
292 | |
293 if( (use_alpha == false) && (use_beta == false) ) | |
294 { | |
295 y[col] = acc; | |
296 } | |
297 else | |
298 if( (use_alpha == true) && (use_beta == false) ) | |
299 { | |
300 y[col] = alpha * acc; | |
301 } | |
302 else | |
303 if( (use_alpha == false) && (use_beta == true) ) | |
304 { | |
305 y[col] = acc + beta*y[col]; | |
306 } | |
307 else | |
308 if( (use_alpha == true) && (use_beta == true) ) | |
309 { | |
310 y[col] = alpha*acc + beta*y[col]; | |
311 } | |
312 | |
313 } | |
314 } | |
315 } | |
316 | |
317 }; | |
318 | |
319 | |
320 | |
321 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false> | |
322 class gemv_emul | |
323 { | |
324 public: | |
325 | |
326 template<typename eT, typename TA> | |
327 arma_hot | |
328 inline | |
329 static | |
330 void | |
331 apply( eT* y, const TA& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0), const typename arma_not_cx<eT>::result* junk = 0 ) | |
332 { | |
333 arma_extra_debug_sigprint(); | |
334 arma_ignore(junk); | |
335 | |
336 const uword A_n_rows = A.n_rows; | |
337 const uword A_n_cols = A.n_cols; | |
338 | |
339 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) ) | |
340 { | |
341 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(y, A, x, alpha, beta); | |
342 } | |
343 else | |
344 { | |
345 gemv_emul_large<do_trans_A, use_alpha, use_beta>::apply(y, A, x, alpha, beta); | |
346 } | |
347 } | |
348 | |
349 | |
350 | |
351 template<typename eT> | |
352 arma_hot | |
353 inline | |
354 static | |
355 void | |
356 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0), const typename arma_cx_only<eT>::result* junk = 0 ) | |
357 { | |
358 arma_extra_debug_sigprint(); | |
359 arma_ignore(junk); | |
360 | |
361 Mat<eT> tmp_A; | |
362 | |
363 if(do_trans_A) | |
364 { | |
365 op_htrans::apply_noalias(tmp_A, A); | |
366 } | |
367 | |
368 const Mat<eT>& AA = (do_trans_A == false) ? A : tmp_A; | |
369 | |
370 const uword AA_n_rows = AA.n_rows; | |
371 const uword AA_n_cols = AA.n_cols; | |
372 | |
373 if( (AA_n_rows <= 4) && (AA_n_rows == AA_n_cols) ) | |
374 { | |
375 gemv_emul_tinysq<false, use_alpha, use_beta>::apply(y, AA, x, alpha, beta); | |
376 } | |
377 else | |
378 { | |
379 gemv_emul_large<false, use_alpha, use_beta>::apply(y, AA, x, alpha, beta); | |
380 } | |
381 } | |
382 }; | |
383 | |
384 | |
385 | |
386 //! \brief | |
387 //! Wrapper for ATLAS/BLAS gemv function, using template arguments to control the arguments passed to gemv. | |
388 //! 'y' is assumed to have been set to the correct size (i.e. taking into account the transpose) | |
389 | |
390 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false> | |
391 class gemv | |
392 { | |
393 public: | |
394 | |
395 template<typename eT, typename TA> | |
396 inline | |
397 static | |
398 void | |
399 apply_blas_type( eT* y, const TA& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) ) | |
400 { | |
401 arma_extra_debug_sigprint(); | |
402 | |
403 //const uword threshold = (is_complex<eT>::value == true) ? 16u : 64u; | |
404 const uword threshold = (is_complex<eT>::value == true) ? 64u : 100u; | |
405 | |
406 if(A.n_elem <= threshold) | |
407 { | |
408 gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta); | |
409 } | |
410 else | |
411 { | |
412 #if defined(ARMA_USE_ATLAS) | |
413 { | |
414 if(is_complex<eT>::value == false) | |
415 { | |
416 // use gemm() instead of gemv() to work around a speed issue in Atlas 3.8.4 | |
417 | |
418 arma_extra_debug_print("atlas::cblas_gemm()"); | |
419 | |
420 atlas::cblas_gemm<eT> | |
421 ( | |
422 atlas::CblasColMajor, | |
423 (do_trans_A) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans, | |
424 atlas::CblasNoTrans, | |
425 (do_trans_A) ? A.n_cols : A.n_rows, | |
426 1, | |
427 (do_trans_A) ? A.n_rows : A.n_cols, | |
428 (use_alpha) ? alpha : eT(1), | |
429 A.mem, | |
430 A.n_rows, | |
431 x, | |
432 (do_trans_A) ? A.n_rows : A.n_cols, | |
433 (use_beta) ? beta : eT(0), | |
434 y, | |
435 (do_trans_A) ? A.n_cols : A.n_rows | |
436 ); | |
437 } | |
438 else | |
439 { | |
440 arma_extra_debug_print("atlas::cblas_gemv()"); | |
441 | |
442 atlas::cblas_gemv<eT> | |
443 ( | |
444 atlas::CblasColMajor, | |
445 (do_trans_A) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans, | |
446 A.n_rows, | |
447 A.n_cols, | |
448 (use_alpha) ? alpha : eT(1), | |
449 A.mem, | |
450 A.n_rows, | |
451 x, | |
452 1, | |
453 (use_beta) ? beta : eT(0), | |
454 y, | |
455 1 | |
456 ); | |
457 } | |
458 } | |
459 #elif defined(ARMA_USE_BLAS) | |
460 { | |
461 arma_extra_debug_print("blas::gemv()"); | |
462 | |
463 const char trans_A = (do_trans_A) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N'; | |
464 const blas_int m = A.n_rows; | |
465 const blas_int n = A.n_cols; | |
466 const eT local_alpha = (use_alpha) ? alpha : eT(1); | |
467 //const blas_int lda = A.n_rows; | |
468 const blas_int inc = 1; | |
469 const eT local_beta = (use_beta) ? beta : eT(0); | |
470 | |
471 arma_extra_debug_print( arma_boost::format("blas::gemv(): trans_A = %c") % trans_A ); | |
472 | |
473 blas::gemv<eT> | |
474 ( | |
475 &trans_A, | |
476 &m, | |
477 &n, | |
478 &local_alpha, | |
479 A.mem, | |
480 &m, // lda | |
481 x, | |
482 &inc, | |
483 &local_beta, | |
484 y, | |
485 &inc | |
486 ); | |
487 } | |
488 #else | |
489 { | |
490 gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta); | |
491 } | |
492 #endif | |
493 } | |
494 | |
495 } | |
496 | |
497 | |
498 | |
499 template<typename eT, typename TA> | |
500 arma_inline | |
501 static | |
502 void | |
503 apply( eT* y, const TA& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) ) | |
504 { | |
505 gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta); | |
506 } | |
507 | |
508 | |
509 | |
510 template<typename TA> | |
511 arma_inline | |
512 static | |
513 void | |
514 apply | |
515 ( | |
516 float* y, | |
517 const TA& A, | |
518 const float* x, | |
519 const float alpha = float(1), | |
520 const float beta = float(0) | |
521 ) | |
522 { | |
523 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta); | |
524 } | |
525 | |
526 | |
527 | |
528 template<typename TA> | |
529 arma_inline | |
530 static | |
531 void | |
532 apply | |
533 ( | |
534 double* y, | |
535 const TA& A, | |
536 const double* x, | |
537 const double alpha = double(1), | |
538 const double beta = double(0) | |
539 ) | |
540 { | |
541 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta); | |
542 } | |
543 | |
544 | |
545 | |
546 template<typename TA> | |
547 arma_inline | |
548 static | |
549 void | |
550 apply | |
551 ( | |
552 std::complex<float>* y, | |
553 const TA& A, | |
554 const std::complex<float>* x, | |
555 const std::complex<float> alpha = std::complex<float>(1), | |
556 const std::complex<float> beta = std::complex<float>(0) | |
557 ) | |
558 { | |
559 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta); | |
560 } | |
561 | |
562 | |
563 | |
564 template<typename TA> | |
565 arma_inline | |
566 static | |
567 void | |
568 apply | |
569 ( | |
570 std::complex<double>* y, | |
571 const TA& A, | |
572 const std::complex<double>* x, | |
573 const std::complex<double> alpha = std::complex<double>(1), | |
574 const std::complex<double> beta = std::complex<double>(0) | |
575 ) | |
576 { | |
577 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta); | |
578 } | |
579 | |
580 | |
581 | |
582 }; | |
583 | |
584 | |
585 //! @} |