comparison armadillo-3.900.4/include/armadillo_bits/gemm.hpp @ 49:1ec0e2823891

Switch to using subrepo copies of qm-dsp, nnls-chroma, vamp-plugin-sdk; update Armadillo version; assume build without external BLAS/LAPACK
author Chris Cannam
date Thu, 13 Jun 2013 10:25:24 +0100
parents
children
comparison
equal deleted inserted replaced
48:69251e11a913 49:1ec0e2823891
1 // Copyright (C) 2008-2013 NICTA (www.nicta.com.au)
2 // Copyright (C) 2008-2013 Conrad Sanderson
3 //
4 // This Source Code Form is subject to the terms of the Mozilla Public
5 // License, v. 2.0. If a copy of the MPL was not distributed with this
6 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
7
8
9 //! \addtogroup gemm
10 //! @{
11
12
13
14 //! for tiny square matrices, size <= 4x4
15 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
16 class gemm_emul_tinysq
17 {
18 public:
19
20
21 template<typename eT, typename TA, typename TB>
22 arma_hot
23 inline
24 static
25 void
26 apply
27 (
28 Mat<eT>& C,
29 const TA& A,
30 const TB& B,
31 const eT alpha = eT(1),
32 const eT beta = eT(0)
33 )
34 {
35 arma_extra_debug_sigprint();
36
37 switch(A.n_rows)
38 {
39 case 4:
40 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(3), A, B.colptr(3), alpha, beta );
41
42 case 3:
43 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(2), A, B.colptr(2), alpha, beta );
44
45 case 2:
46 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(1), A, B.colptr(1), alpha, beta );
47
48 case 1:
49 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(0), A, B.colptr(0), alpha, beta );
50
51 default:
52 ;
53 }
54 }
55
56 };
57
58
59
60 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
61 class gemm_emul_large
62 {
63 public:
64
65 template<typename eT, typename TA, typename TB>
66 arma_hot
67 inline
68 static
69 void
70 apply
71 (
72 Mat<eT>& C,
73 const TA& A,
74 const TB& B,
75 const eT alpha = eT(1),
76 const eT beta = eT(0)
77 )
78 {
79 arma_extra_debug_sigprint();
80
81 const uword A_n_rows = A.n_rows;
82 const uword A_n_cols = A.n_cols;
83
84 const uword B_n_rows = B.n_rows;
85 const uword B_n_cols = B.n_cols;
86
87 if( (do_trans_A == false) && (do_trans_B == false) )
88 {
89 arma_aligned podarray<eT> tmp(A_n_cols);
90
91 eT* A_rowdata = tmp.memptr();
92
93 for(uword row_A=0; row_A < A_n_rows; ++row_A)
94 {
95 //tmp.copy_row(A, row_A);
96 const eT acc0 = op_dot::dot_and_copy_row(A_rowdata, A, row_A, B.colptr(0), A_n_cols);
97
98 if( (use_alpha == false) && (use_beta == false) )
99 {
100 C.at(row_A,0) = acc0;
101 }
102 else
103 if( (use_alpha == true) && (use_beta == false) )
104 {
105 C.at(row_A,0) = alpha * acc0;
106 }
107 else
108 if( (use_alpha == false) && (use_beta == true) )
109 {
110 C.at(row_A,0) = acc0 + beta*C.at(row_A,0);
111 }
112 else
113 if( (use_alpha == true) && (use_beta == true) )
114 {
115 C.at(row_A,0) = alpha*acc0 + beta*C.at(row_A,0);
116 }
117
118 //for(uword col_B=0; col_B < B_n_cols; ++col_B)
119 for(uword col_B=1; col_B < B_n_cols; ++col_B)
120 {
121 const eT acc = op_dot::direct_dot_arma(B_n_rows, A_rowdata, B.colptr(col_B));
122
123 if( (use_alpha == false) && (use_beta == false) )
124 {
125 C.at(row_A,col_B) = acc;
126 }
127 else
128 if( (use_alpha == true) && (use_beta == false) )
129 {
130 C.at(row_A,col_B) = alpha * acc;
131 }
132 else
133 if( (use_alpha == false) && (use_beta == true) )
134 {
135 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
136 }
137 else
138 if( (use_alpha == true) && (use_beta == true) )
139 {
140 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
141 }
142
143 }
144 }
145 }
146 else
147 if( (do_trans_A == true) && (do_trans_B == false) )
148 {
149 for(uword col_A=0; col_A < A_n_cols; ++col_A)
150 {
151 // col_A is interpreted as row_A when storing the results in matrix C
152
153 const eT* A_coldata = A.colptr(col_A);
154
155 for(uword col_B=0; col_B < B_n_cols; ++col_B)
156 {
157 const eT acc = op_dot::direct_dot_arma(B_n_rows, A_coldata, B.colptr(col_B));
158
159 if( (use_alpha == false) && (use_beta == false) )
160 {
161 C.at(col_A,col_B) = acc;
162 }
163 else
164 if( (use_alpha == true) && (use_beta == false) )
165 {
166 C.at(col_A,col_B) = alpha * acc;
167 }
168 else
169 if( (use_alpha == false) && (use_beta == true) )
170 {
171 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
172 }
173 else
174 if( (use_alpha == true) && (use_beta == true) )
175 {
176 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
177 }
178
179 }
180 }
181 }
182 else
183 if( (do_trans_A == false) && (do_trans_B == true) )
184 {
185 Mat<eT> BB;
186 op_strans::apply_noalias(BB, B);
187
188 gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, A, BB, alpha, beta);
189 }
190 else
191 if( (do_trans_A == true) && (do_trans_B == true) )
192 {
193 // mat B_tmp = trans(B);
194 // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
195
196
197 // By using the trans(A)*trans(B) = trans(B*A) equivalency,
198 // transpose operations are not needed
199
200 arma_aligned podarray<eT> tmp(B.n_cols);
201 eT* B_rowdata = tmp.memptr();
202
203 for(uword row_B=0; row_B < B_n_rows; ++row_B)
204 {
205 tmp.copy_row(B, row_B);
206
207 for(uword col_A=0; col_A < A_n_cols; ++col_A)
208 {
209 const eT acc = op_dot::direct_dot_arma(A_n_rows, B_rowdata, A.colptr(col_A));
210
211 if( (use_alpha == false) && (use_beta == false) )
212 {
213 C.at(col_A,row_B) = acc;
214 }
215 else
216 if( (use_alpha == true) && (use_beta == false) )
217 {
218 C.at(col_A,row_B) = alpha * acc;
219 }
220 else
221 if( (use_alpha == false) && (use_beta == true) )
222 {
223 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
224 }
225 else
226 if( (use_alpha == true) && (use_beta == true) )
227 {
228 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
229 }
230
231 }
232 }
233
234 }
235 }
236
237 };
238
239
240
241 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
242 class gemm_emul
243 {
244 public:
245
246
247 template<typename eT, typename TA, typename TB>
248 arma_hot
249 inline
250 static
251 void
252 apply
253 (
254 Mat<eT>& C,
255 const TA& A,
256 const TB& B,
257 const eT alpha = eT(1),
258 const eT beta = eT(0),
259 const typename arma_not_cx<eT>::result* junk = 0
260 )
261 {
262 arma_extra_debug_sigprint();
263 arma_ignore(junk);
264
265 const uword A_n_rows = A.n_rows;
266 const uword A_n_cols = A.n_cols;
267
268 const uword B_n_rows = B.n_rows;
269 const uword B_n_cols = B.n_cols;
270
271 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) && (A_n_rows == B_n_rows) && (B_n_rows == B_n_cols) )
272 {
273 if(do_trans_B == false)
274 {
275 gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, B, alpha, beta);
276 }
277 else
278 {
279 Mat<eT> BB(A_n_rows, A_n_rows);
280 op_strans::apply_noalias_tinysq(BB, B);
281
282 gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, BB, alpha, beta);
283 }
284 }
285 else
286 {
287 gemm_emul_large<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C, A, B, alpha, beta);
288 }
289 }
290
291
292
293 template<typename eT>
294 arma_hot
295 inline
296 static
297 void
298 apply
299 (
300 Mat<eT>& C,
301 const Mat<eT>& A,
302 const Mat<eT>& B,
303 const eT alpha = eT(1),
304 const eT beta = eT(0),
305 const typename arma_cx_only<eT>::result* junk = 0
306 )
307 {
308 arma_extra_debug_sigprint();
309 arma_ignore(junk);
310
311 // "better than nothing" handling of hermitian transposes for complex number matrices
312
313 Mat<eT> tmp_A;
314 Mat<eT> tmp_B;
315
316 if(do_trans_A)
317 {
318 op_htrans::apply_noalias(tmp_A, A);
319 }
320
321 if(do_trans_B)
322 {
323 op_htrans::apply_noalias(tmp_B, B);
324 }
325
326 const Mat<eT>& AA = (do_trans_A == false) ? A : tmp_A;
327 const Mat<eT>& BB = (do_trans_B == false) ? B : tmp_B;
328
329 const uword A_n_rows = AA.n_rows;
330 const uword A_n_cols = AA.n_cols;
331
332 const uword B_n_rows = BB.n_rows;
333 const uword B_n_cols = BB.n_cols;
334
335 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) && (A_n_rows == B_n_rows) && (B_n_rows == B_n_cols) )
336 {
337 gemm_emul_tinysq<false, use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
338 }
339 else
340 {
341 gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
342 }
343 }
344
345 };
346
347
348
349 //! \brief
350 //! Wrapper for ATLAS/BLAS dgemm function, using template arguments to control the arguments passed to dgemm.
351 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
352
353 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
354 class gemm
355 {
356 public:
357
358 template<typename eT, typename TA, typename TB>
359 inline
360 static
361 void
362 apply_blas_type( Mat<eT>& C, const TA& A, const TB& B, const eT alpha = eT(1), const eT beta = eT(0) )
363 {
364 arma_extra_debug_sigprint();
365
366 const uword threshold = (is_Mat_fixed<TA>::value && is_Mat_fixed<TB>::value)
367 ? (is_complex<eT>::value ? 16u : 64u)
368 : (is_complex<eT>::value ? 16u : 48u);
369
370 if( (A.n_elem <= threshold) && (B.n_elem <= threshold) )
371 {
372 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
373 }
374 else
375 {
376 #if defined(ARMA_USE_ATLAS)
377 {
378 arma_extra_debug_print("atlas::cblas_gemm()");
379
380 atlas::cblas_gemm<eT>
381 (
382 atlas::CblasColMajor,
383 (do_trans_A) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
384 (do_trans_B) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
385 C.n_rows,
386 C.n_cols,
387 (do_trans_A) ? A.n_rows : A.n_cols,
388 (use_alpha) ? alpha : eT(1),
389 A.mem,
390 (do_trans_A) ? A.n_rows : C.n_rows,
391 B.mem,
392 (do_trans_B) ? C.n_cols : ( (do_trans_A) ? A.n_rows : A.n_cols ),
393 (use_beta) ? beta : eT(0),
394 C.memptr(),
395 C.n_rows
396 );
397 }
398 #elif defined(ARMA_USE_BLAS)
399 {
400 arma_extra_debug_print("blas::gemm()");
401
402 const char trans_A = (do_trans_A) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N';
403 const char trans_B = (do_trans_B) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N';
404
405 const blas_int m = C.n_rows;
406 const blas_int n = C.n_cols;
407 const blas_int k = (do_trans_A) ? A.n_rows : A.n_cols;
408
409 const eT local_alpha = (use_alpha) ? alpha : eT(1);
410
411 const blas_int lda = (do_trans_A) ? k : m;
412 const blas_int ldb = (do_trans_B) ? n : k;
413
414 const eT local_beta = (use_beta) ? beta : eT(0);
415
416 arma_extra_debug_print( arma_boost::format("blas::gemm(): trans_A = %c") % trans_A );
417 arma_extra_debug_print( arma_boost::format("blas::gemm(): trans_B = %c") % trans_B );
418
419 blas::gemm<eT>
420 (
421 &trans_A,
422 &trans_B,
423 &m,
424 &n,
425 &k,
426 &local_alpha,
427 A.mem,
428 &lda,
429 B.mem,
430 &ldb,
431 &local_beta,
432 C.memptr(),
433 &m
434 );
435 }
436 #else
437 {
438 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
439 }
440 #endif
441 }
442 }
443
444
445
446 //! immediate multiplication of matrices A and B, storing the result in C
447 template<typename eT, typename TA, typename TB>
448 inline
449 static
450 void
451 apply( Mat<eT>& C, const TA& A, const TB& B, const eT alpha = eT(1), const eT beta = eT(0) )
452 {
453 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
454 }
455
456
457
458 template<typename TA, typename TB>
459 arma_inline
460 static
461 void
462 apply
463 (
464 Mat<float>& C,
465 const TA& A,
466 const TB& B,
467 const float alpha = float(1),
468 const float beta = float(0)
469 )
470 {
471 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
472 }
473
474
475
476 template<typename TA, typename TB>
477 arma_inline
478 static
479 void
480 apply
481 (
482 Mat<double>& C,
483 const TA& A,
484 const TB& B,
485 const double alpha = double(1),
486 const double beta = double(0)
487 )
488 {
489 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
490 }
491
492
493
494 template<typename TA, typename TB>
495 arma_inline
496 static
497 void
498 apply
499 (
500 Mat< std::complex<float> >& C,
501 const TA& A,
502 const TB& B,
503 const std::complex<float> alpha = std::complex<float>(1),
504 const std::complex<float> beta = std::complex<float>(0)
505 )
506 {
507 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
508 }
509
510
511
512 template<typename TA, typename TB>
513 arma_inline
514 static
515 void
516 apply
517 (
518 Mat< std::complex<double> >& C,
519 const TA& A,
520 const TB& B,
521 const std::complex<double> alpha = std::complex<double>(1),
522 const std::complex<double> beta = std::complex<double>(0)
523 )
524 {
525 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
526 }
527
528 };
529
530
531
532 //! @}