Mercurial > hg > segmenter-vamp-plugin
comparison armadillo-3.900.4/include/armadillo_bits/gemm.hpp @ 49:1ec0e2823891
Switch to using subrepo copies of qm-dsp, nnls-chroma, vamp-plugin-sdk; update Armadillo version; assume build without external BLAS/LAPACK
author | Chris Cannam |
---|---|
date | Thu, 13 Jun 2013 10:25:24 +0100 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
48:69251e11a913 | 49:1ec0e2823891 |
---|---|
1 // Copyright (C) 2008-2013 NICTA (www.nicta.com.au) | |
2 // Copyright (C) 2008-2013 Conrad Sanderson | |
3 // | |
4 // This Source Code Form is subject to the terms of the Mozilla Public | |
5 // License, v. 2.0. If a copy of the MPL was not distributed with this | |
6 // file, You can obtain one at http://mozilla.org/MPL/2.0/. | |
7 | |
8 | |
9 //! \addtogroup gemm | |
10 //! @{ | |
11 | |
12 | |
13 | |
14 //! for tiny square matrices, size <= 4x4 | |
15 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false> | |
16 class gemm_emul_tinysq | |
17 { | |
18 public: | |
19 | |
20 | |
21 template<typename eT, typename TA, typename TB> | |
22 arma_hot | |
23 inline | |
24 static | |
25 void | |
26 apply | |
27 ( | |
28 Mat<eT>& C, | |
29 const TA& A, | |
30 const TB& B, | |
31 const eT alpha = eT(1), | |
32 const eT beta = eT(0) | |
33 ) | |
34 { | |
35 arma_extra_debug_sigprint(); | |
36 | |
37 switch(A.n_rows) | |
38 { | |
39 case 4: | |
40 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(3), A, B.colptr(3), alpha, beta ); | |
41 | |
42 case 3: | |
43 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(2), A, B.colptr(2), alpha, beta ); | |
44 | |
45 case 2: | |
46 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(1), A, B.colptr(1), alpha, beta ); | |
47 | |
48 case 1: | |
49 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(0), A, B.colptr(0), alpha, beta ); | |
50 | |
51 default: | |
52 ; | |
53 } | |
54 } | |
55 | |
56 }; | |
57 | |
58 | |
59 | |
60 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false> | |
61 class gemm_emul_large | |
62 { | |
63 public: | |
64 | |
65 template<typename eT, typename TA, typename TB> | |
66 arma_hot | |
67 inline | |
68 static | |
69 void | |
70 apply | |
71 ( | |
72 Mat<eT>& C, | |
73 const TA& A, | |
74 const TB& B, | |
75 const eT alpha = eT(1), | |
76 const eT beta = eT(0) | |
77 ) | |
78 { | |
79 arma_extra_debug_sigprint(); | |
80 | |
81 const uword A_n_rows = A.n_rows; | |
82 const uword A_n_cols = A.n_cols; | |
83 | |
84 const uword B_n_rows = B.n_rows; | |
85 const uword B_n_cols = B.n_cols; | |
86 | |
87 if( (do_trans_A == false) && (do_trans_B == false) ) | |
88 { | |
89 arma_aligned podarray<eT> tmp(A_n_cols); | |
90 | |
91 eT* A_rowdata = tmp.memptr(); | |
92 | |
93 for(uword row_A=0; row_A < A_n_rows; ++row_A) | |
94 { | |
95 //tmp.copy_row(A, row_A); | |
96 const eT acc0 = op_dot::dot_and_copy_row(A_rowdata, A, row_A, B.colptr(0), A_n_cols); | |
97 | |
98 if( (use_alpha == false) && (use_beta == false) ) | |
99 { | |
100 C.at(row_A,0) = acc0; | |
101 } | |
102 else | |
103 if( (use_alpha == true) && (use_beta == false) ) | |
104 { | |
105 C.at(row_A,0) = alpha * acc0; | |
106 } | |
107 else | |
108 if( (use_alpha == false) && (use_beta == true) ) | |
109 { | |
110 C.at(row_A,0) = acc0 + beta*C.at(row_A,0); | |
111 } | |
112 else | |
113 if( (use_alpha == true) && (use_beta == true) ) | |
114 { | |
115 C.at(row_A,0) = alpha*acc0 + beta*C.at(row_A,0); | |
116 } | |
117 | |
118 //for(uword col_B=0; col_B < B_n_cols; ++col_B) | |
119 for(uword col_B=1; col_B < B_n_cols; ++col_B) | |
120 { | |
121 const eT acc = op_dot::direct_dot_arma(B_n_rows, A_rowdata, B.colptr(col_B)); | |
122 | |
123 if( (use_alpha == false) && (use_beta == false) ) | |
124 { | |
125 C.at(row_A,col_B) = acc; | |
126 } | |
127 else | |
128 if( (use_alpha == true) && (use_beta == false) ) | |
129 { | |
130 C.at(row_A,col_B) = alpha * acc; | |
131 } | |
132 else | |
133 if( (use_alpha == false) && (use_beta == true) ) | |
134 { | |
135 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); | |
136 } | |
137 else | |
138 if( (use_alpha == true) && (use_beta == true) ) | |
139 { | |
140 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); | |
141 } | |
142 | |
143 } | |
144 } | |
145 } | |
146 else | |
147 if( (do_trans_A == true) && (do_trans_B == false) ) | |
148 { | |
149 for(uword col_A=0; col_A < A_n_cols; ++col_A) | |
150 { | |
151 // col_A is interpreted as row_A when storing the results in matrix C | |
152 | |
153 const eT* A_coldata = A.colptr(col_A); | |
154 | |
155 for(uword col_B=0; col_B < B_n_cols; ++col_B) | |
156 { | |
157 const eT acc = op_dot::direct_dot_arma(B_n_rows, A_coldata, B.colptr(col_B)); | |
158 | |
159 if( (use_alpha == false) && (use_beta == false) ) | |
160 { | |
161 C.at(col_A,col_B) = acc; | |
162 } | |
163 else | |
164 if( (use_alpha == true) && (use_beta == false) ) | |
165 { | |
166 C.at(col_A,col_B) = alpha * acc; | |
167 } | |
168 else | |
169 if( (use_alpha == false) && (use_beta == true) ) | |
170 { | |
171 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); | |
172 } | |
173 else | |
174 if( (use_alpha == true) && (use_beta == true) ) | |
175 { | |
176 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); | |
177 } | |
178 | |
179 } | |
180 } | |
181 } | |
182 else | |
183 if( (do_trans_A == false) && (do_trans_B == true) ) | |
184 { | |
185 Mat<eT> BB; | |
186 op_strans::apply_noalias(BB, B); | |
187 | |
188 gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, A, BB, alpha, beta); | |
189 } | |
190 else | |
191 if( (do_trans_A == true) && (do_trans_B == true) ) | |
192 { | |
193 // mat B_tmp = trans(B); | |
194 // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta); | |
195 | |
196 | |
197 // By using the trans(A)*trans(B) = trans(B*A) equivalency, | |
198 // transpose operations are not needed | |
199 | |
200 arma_aligned podarray<eT> tmp(B.n_cols); | |
201 eT* B_rowdata = tmp.memptr(); | |
202 | |
203 for(uword row_B=0; row_B < B_n_rows; ++row_B) | |
204 { | |
205 tmp.copy_row(B, row_B); | |
206 | |
207 for(uword col_A=0; col_A < A_n_cols; ++col_A) | |
208 { | |
209 const eT acc = op_dot::direct_dot_arma(A_n_rows, B_rowdata, A.colptr(col_A)); | |
210 | |
211 if( (use_alpha == false) && (use_beta == false) ) | |
212 { | |
213 C.at(col_A,row_B) = acc; | |
214 } | |
215 else | |
216 if( (use_alpha == true) && (use_beta == false) ) | |
217 { | |
218 C.at(col_A,row_B) = alpha * acc; | |
219 } | |
220 else | |
221 if( (use_alpha == false) && (use_beta == true) ) | |
222 { | |
223 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B); | |
224 } | |
225 else | |
226 if( (use_alpha == true) && (use_beta == true) ) | |
227 { | |
228 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); | |
229 } | |
230 | |
231 } | |
232 } | |
233 | |
234 } | |
235 } | |
236 | |
237 }; | |
238 | |
239 | |
240 | |
241 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false> | |
242 class gemm_emul | |
243 { | |
244 public: | |
245 | |
246 | |
247 template<typename eT, typename TA, typename TB> | |
248 arma_hot | |
249 inline | |
250 static | |
251 void | |
252 apply | |
253 ( | |
254 Mat<eT>& C, | |
255 const TA& A, | |
256 const TB& B, | |
257 const eT alpha = eT(1), | |
258 const eT beta = eT(0), | |
259 const typename arma_not_cx<eT>::result* junk = 0 | |
260 ) | |
261 { | |
262 arma_extra_debug_sigprint(); | |
263 arma_ignore(junk); | |
264 | |
265 const uword A_n_rows = A.n_rows; | |
266 const uword A_n_cols = A.n_cols; | |
267 | |
268 const uword B_n_rows = B.n_rows; | |
269 const uword B_n_cols = B.n_cols; | |
270 | |
271 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) && (A_n_rows == B_n_rows) && (B_n_rows == B_n_cols) ) | |
272 { | |
273 if(do_trans_B == false) | |
274 { | |
275 gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, B, alpha, beta); | |
276 } | |
277 else | |
278 { | |
279 Mat<eT> BB(A_n_rows, A_n_rows); | |
280 op_strans::apply_noalias_tinysq(BB, B); | |
281 | |
282 gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, BB, alpha, beta); | |
283 } | |
284 } | |
285 else | |
286 { | |
287 gemm_emul_large<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C, A, B, alpha, beta); | |
288 } | |
289 } | |
290 | |
291 | |
292 | |
293 template<typename eT> | |
294 arma_hot | |
295 inline | |
296 static | |
297 void | |
298 apply | |
299 ( | |
300 Mat<eT>& C, | |
301 const Mat<eT>& A, | |
302 const Mat<eT>& B, | |
303 const eT alpha = eT(1), | |
304 const eT beta = eT(0), | |
305 const typename arma_cx_only<eT>::result* junk = 0 | |
306 ) | |
307 { | |
308 arma_extra_debug_sigprint(); | |
309 arma_ignore(junk); | |
310 | |
311 // "better than nothing" handling of hermitian transposes for complex number matrices | |
312 | |
313 Mat<eT> tmp_A; | |
314 Mat<eT> tmp_B; | |
315 | |
316 if(do_trans_A) | |
317 { | |
318 op_htrans::apply_noalias(tmp_A, A); | |
319 } | |
320 | |
321 if(do_trans_B) | |
322 { | |
323 op_htrans::apply_noalias(tmp_B, B); | |
324 } | |
325 | |
326 const Mat<eT>& AA = (do_trans_A == false) ? A : tmp_A; | |
327 const Mat<eT>& BB = (do_trans_B == false) ? B : tmp_B; | |
328 | |
329 const uword A_n_rows = AA.n_rows; | |
330 const uword A_n_cols = AA.n_cols; | |
331 | |
332 const uword B_n_rows = BB.n_rows; | |
333 const uword B_n_cols = BB.n_cols; | |
334 | |
335 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) && (A_n_rows == B_n_rows) && (B_n_rows == B_n_cols) ) | |
336 { | |
337 gemm_emul_tinysq<false, use_alpha, use_beta>::apply(C, AA, BB, alpha, beta); | |
338 } | |
339 else | |
340 { | |
341 gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, AA, BB, alpha, beta); | |
342 } | |
343 } | |
344 | |
345 }; | |
346 | |
347 | |
348 | |
349 //! \brief | |
350 //! Wrapper for the ATLAS/BLAS gemm functions, using template arguments to control the arguments passed to gemm. | |
351 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes). | |
//! The apply() overloads for float, double and their std::complex variants go through apply_blas_type(); | |
//! the generic apply() overload always uses gemm_emul. | |
352 | |
353 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false> | |
354 class gemm | |
355 { | |
356 public: | |
357 | |
//! Multiply A and B into C for element types that have a BLAS counterpart. | |
//! Operands whose element counts are both within 'threshold' are handled by gemm_emul. | |
//! Larger operands are passed to atlas::cblas_gemm() when ARMA_USE_ATLAS is defined, | |
//! to blas::gemm() when ARMA_USE_BLAS is defined, and to gemm_emul otherwise. | |
358 template<typename eT, typename TA, typename TB> | |
359 inline | |
360 static | |
361 void | |
362 apply_blas_type( Mat<eT>& C, const TA& A, const TB& B, const eT alpha = eT(1), const eT beta = eT(0) ) | |
363 { | |
364 arma_extra_debug_sigprint(); | |
365 | |
// element-count limit for the emulated path: | |
// 16 for complex eT; for other eT, 64 when both operands are fixed-size matrices and 48 otherwise | |
366 const uword threshold = (is_Mat_fixed<TA>::value && is_Mat_fixed<TB>::value) | |
367 ? (is_complex<eT>::value ? 16u : 64u) | |
368 : (is_complex<eT>::value ? 16u : 48u); | |
369 | |
370 if( (A.n_elem <= threshold) && (B.n_elem <= threshold) ) | |
371 { | |
372 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta); | |
373 } | |
374 else | |
375 { | |
376 #if defined(ARMA_USE_ATLAS) | |
377 { | |
378 arma_extra_debug_print("atlas::cblas_gemm()"); | |
379 | |
// a requested transpose is passed as a conjugate transpose when eT is complex; | |
// alpha and beta are replaced by eT(1) and eT(0) when the matching use_* flag is off | |
380 atlas::cblas_gemm<eT> | |
381 ( | |
382 atlas::CblasColMajor, | |
383 (do_trans_A) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans, | |
384 (do_trans_B) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans, | |
385 C.n_rows, | |
386 C.n_cols, | |
387 (do_trans_A) ? A.n_rows : A.n_cols, | |
388 (use_alpha) ? alpha : eT(1), | |
389 A.mem, | |
390 (do_trans_A) ? A.n_rows : C.n_rows, | |
391 B.mem, | |
392 (do_trans_B) ? C.n_cols : ( (do_trans_A) ? A.n_rows : A.n_cols ), | |
393 (use_beta) ? beta : eT(0), | |
394 C.memptr(), | |
395 C.n_rows | |
396 ); | |
397 } | |
398 #elif defined(ARMA_USE_BLAS) | |
399 { | |
400 arma_extra_debug_print("blas::gemm()"); | |
401 | |
// 'C' (conjugate transpose) is used instead of 'T' when eT is complex | |
402 const char trans_A = (do_trans_A) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N'; | |
403 const char trans_B = (do_trans_B) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N'; | |
404 | |
// m and n are the dimensions of C; k is the dimension shared by the two operands of the product | |
405 const blas_int m = C.n_rows; | |
406 const blas_int n = C.n_cols; | |
407 const blas_int k = (do_trans_A) ? A.n_rows : A.n_cols; | |
408 | |
409 const eT local_alpha = (use_alpha) ? alpha : eT(1); | |
410 | |
// leading dimensions passed for A and B: k or m for A, n or k for B, depending on the transpose flags | |
411 const blas_int lda = (do_trans_A) ? k : m; | |
412 const blas_int ldb = (do_trans_B) ? n : k; | |
413 | |
414 const eT local_beta = (use_beta) ? beta : eT(0); | |
415 | |
416 arma_extra_debug_print( arma_boost::format("blas::gemm(): trans_A = %c") % trans_A ); | |
417 arma_extra_debug_print( arma_boost::format("blas::gemm(): trans_B = %c") % trans_B ); | |
418 | |
// every argument is passed by address; m is also passed as the leading dimension of C | |
419 blas::gemm<eT> | |
420 ( | |
421 &trans_A, | |
422 &trans_B, | |
423 &m, | |
424 &n, | |
425 &k, | |
426 &local_alpha, | |
427 A.mem, | |
428 &lda, | |
429 B.mem, | |
430 &ldb, | |
431 &local_beta, | |
432 C.memptr(), | |
433 &m | |
434 ); | |
435 } | |
436 #else | |
437 { | |
// neither ATLAS nor BLAS is enabled: use the emulated implementation for every size | |
438 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta); | |
439 } | |
440 #endif | |
441 } | |
442 } | |
443 | |
444 | |
445 | |
446 //! immediate multiplication of matrices A and B, storing the result in C | |
//! (generic overload: always uses gemm_emul) | |
447 template<typename eT, typename TA, typename TB> | |
448 inline | |
449 static | |
450 void | |
451 apply( Mat<eT>& C, const TA& A, const TB& B, const eT alpha = eT(1), const eT beta = eT(0) ) | |
452 { | |
453 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta); | |
454 } | |
455 | |
456 | |
457 | |
//! overload for float matrices: forwards to apply_blas_type() | |
458 template<typename TA, typename TB> | |
459 arma_inline | |
460 static | |
461 void | |
462 apply | |
463 ( | |
464 Mat<float>& C, | |
465 const TA& A, | |
466 const TB& B, | |
467 const float alpha = float(1), | |
468 const float beta = float(0) | |
469 ) | |
470 { | |
471 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta); | |
472 } | |
473 | |
474 | |
475 | |
//! overload for double matrices: forwards to apply_blas_type() | |
476 template<typename TA, typename TB> | |
477 arma_inline | |
478 static | |
479 void | |
480 apply | |
481 ( | |
482 Mat<double>& C, | |
483 const TA& A, | |
484 const TB& B, | |
485 const double alpha = double(1), | |
486 const double beta = double(0) | |
487 ) | |
488 { | |
489 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta); | |
490 } | |
491 | |
492 | |
493 | |
//! overload for std::complex<float> matrices: forwards to apply_blas_type() | |
494 template<typename TA, typename TB> | |
495 arma_inline | |
496 static | |
497 void | |
498 apply | |
499 ( | |
500 Mat< std::complex<float> >& C, | |
501 const TA& A, | |
502 const TB& B, | |
503 const std::complex<float> alpha = std::complex<float>(1), | |
504 const std::complex<float> beta = std::complex<float>(0) | |
505 ) | |
506 { | |
507 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta); | |
508 } | |
509 | |
510 | |
511 | |
//! overload for std::complex<double> matrices: forwards to apply_blas_type() | |
512 template<typename TA, typename TB> | |
513 arma_inline | |
514 static | |
515 void | |
516 apply | |
517 ( | |
518 Mat< std::complex<double> >& C, | |
519 const TA& A, | |
520 const TB& B, | |
521 const std::complex<double> alpha = std::complex<double>(1), | |
522 const std::complex<double> beta = std::complex<double>(0) | |
523 ) | |
524 { | |
525 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta); | |
526 } | |
527 | |
528 }; | |
529 | |
530 | |
531 | |
532 //! @} |