Mercurial > hg > segmenter-vamp-plugin
comparison armadillo-2.4.4/include/armadillo_bits/gemm.hpp @ 0:8b6102e2a9b0
Armadillo Library
author | maxzanoni76 <max.zanoni@eecs.qmul.ac.uk> |
---|---|
date | Wed, 11 Apr 2012 09:27:06 +0100 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:8b6102e2a9b0 |
---|---|
// Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
// Copyright (C) 2008-2011 Conrad Sanderson
// 
// This file is part of the Armadillo C++ library.
// It is provided without any warranty of fitness
// for any purpose. You can redistribute this file
// and/or modify it under the terms of the GNU
// Lesser General Public License (LGPL) as published
// by the Free Software Foundation, either version 3
// of the License or (at your option) any later version.
// (see http://www.opensource.org/licenses for more info)
12 | |
13 | |
14 //! \addtogroup gemm | |
15 //! @{ | |
16 | |
17 | |
18 | |
19 //! for tiny square matrices, size <= 4x4 | |
20 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false> | |
21 class gemm_emul_tinysq | |
22 { | |
23 public: | |
24 | |
25 | |
26 template<typename eT> | |
27 arma_hot | |
28 inline | |
29 static | |
30 void | |
31 apply | |
32 ( | |
33 Mat<eT>& C, | |
34 const Mat<eT>& A, | |
35 const Mat<eT>& B, | |
36 const eT alpha = eT(1), | |
37 const eT beta = eT(0) | |
38 ) | |
39 { | |
40 arma_extra_debug_sigprint(); | |
41 | |
42 switch(A.n_rows) | |
43 { | |
44 case 4: | |
45 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(3), A, B.colptr(3), alpha, beta ); | |
46 | |
47 case 3: | |
48 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(2), A, B.colptr(2), alpha, beta ); | |
49 | |
50 case 2: | |
51 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(1), A, B.colptr(1), alpha, beta ); | |
52 | |
53 case 1: | |
54 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(0), A, B.colptr(0), alpha, beta ); | |
55 | |
56 default: | |
57 ; | |
58 } | |
59 } | |
60 | |
61 }; | |
62 | |
63 | |
64 | |
65 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false> | |
66 class gemm_emul_large | |
67 { | |
68 public: | |
69 | |
70 template<typename eT> | |
71 arma_hot | |
72 inline | |
73 static | |
74 void | |
75 apply | |
76 ( | |
77 Mat<eT>& C, | |
78 const Mat<eT>& A, | |
79 const Mat<eT>& B, | |
80 const eT alpha = eT(1), | |
81 const eT beta = eT(0) | |
82 ) | |
83 { | |
84 arma_extra_debug_sigprint(); | |
85 | |
86 const uword A_n_rows = A.n_rows; | |
87 const uword A_n_cols = A.n_cols; | |
88 | |
89 const uword B_n_rows = B.n_rows; | |
90 const uword B_n_cols = B.n_cols; | |
91 | |
92 if( (do_trans_A == false) && (do_trans_B == false) ) | |
93 { | |
94 arma_aligned podarray<eT> tmp(A_n_cols); | |
95 eT* A_rowdata = tmp.memptr(); | |
96 | |
97 for(uword row_A=0; row_A < A_n_rows; ++row_A) | |
98 { | |
99 tmp.copy_row(A, row_A); | |
100 | |
101 for(uword col_B=0; col_B < B_n_cols; ++col_B) | |
102 { | |
103 const eT acc = op_dot::direct_dot_arma(B_n_rows, A_rowdata, B.colptr(col_B)); | |
104 | |
105 if( (use_alpha == false) && (use_beta == false) ) | |
106 { | |
107 C.at(row_A,col_B) = acc; | |
108 } | |
109 else | |
110 if( (use_alpha == true) && (use_beta == false) ) | |
111 { | |
112 C.at(row_A,col_B) = alpha * acc; | |
113 } | |
114 else | |
115 if( (use_alpha == false) && (use_beta == true) ) | |
116 { | |
117 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); | |
118 } | |
119 else | |
120 if( (use_alpha == true) && (use_beta == true) ) | |
121 { | |
122 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); | |
123 } | |
124 | |
125 } | |
126 } | |
127 } | |
128 else | |
129 if( (do_trans_A == true) && (do_trans_B == false) ) | |
130 { | |
131 for(uword col_A=0; col_A < A_n_cols; ++col_A) | |
132 { | |
133 // col_A is interpreted as row_A when storing the results in matrix C | |
134 | |
135 const eT* A_coldata = A.colptr(col_A); | |
136 | |
137 for(uword col_B=0; col_B < B_n_cols; ++col_B) | |
138 { | |
139 const eT acc = op_dot::direct_dot_arma(B_n_rows, A_coldata, B.colptr(col_B)); | |
140 | |
141 if( (use_alpha == false) && (use_beta == false) ) | |
142 { | |
143 C.at(col_A,col_B) = acc; | |
144 } | |
145 else | |
146 if( (use_alpha == true) && (use_beta == false) ) | |
147 { | |
148 C.at(col_A,col_B) = alpha * acc; | |
149 } | |
150 else | |
151 if( (use_alpha == false) && (use_beta == true) ) | |
152 { | |
153 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); | |
154 } | |
155 else | |
156 if( (use_alpha == true) && (use_beta == true) ) | |
157 { | |
158 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); | |
159 } | |
160 | |
161 } | |
162 } | |
163 } | |
164 else | |
165 if( (do_trans_A == false) && (do_trans_B == true) ) | |
166 { | |
167 Mat<eT> BB; | |
168 op_strans::apply_noalias(BB, B); | |
169 | |
170 gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, A, BB, alpha, beta); | |
171 } | |
172 else | |
173 if( (do_trans_A == true) && (do_trans_B == true) ) | |
174 { | |
175 // mat B_tmp = trans(B); | |
176 // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta); | |
177 | |
178 | |
179 // By using the trans(A)*trans(B) = trans(B*A) equivalency, | |
180 // transpose operations are not needed | |
181 | |
182 arma_aligned podarray<eT> tmp(B.n_cols); | |
183 eT* B_rowdata = tmp.memptr(); | |
184 | |
185 for(uword row_B=0; row_B < B_n_rows; ++row_B) | |
186 { | |
187 tmp.copy_row(B, row_B); | |
188 | |
189 for(uword col_A=0; col_A < A_n_cols; ++col_A) | |
190 { | |
191 const eT acc = op_dot::direct_dot_arma(A_n_rows, B_rowdata, A.colptr(col_A)); | |
192 | |
193 if( (use_alpha == false) && (use_beta == false) ) | |
194 { | |
195 C.at(col_A,row_B) = acc; | |
196 } | |
197 else | |
198 if( (use_alpha == true) && (use_beta == false) ) | |
199 { | |
200 C.at(col_A,row_B) = alpha * acc; | |
201 } | |
202 else | |
203 if( (use_alpha == false) && (use_beta == true) ) | |
204 { | |
205 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B); | |
206 } | |
207 else | |
208 if( (use_alpha == true) && (use_beta == true) ) | |
209 { | |
210 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); | |
211 } | |
212 | |
213 } | |
214 } | |
215 | |
216 } | |
217 } | |
218 | |
219 }; | |
220 | |
221 | |
222 | |
223 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false> | |
224 class gemm_emul | |
225 { | |
226 public: | |
227 | |
228 | |
229 template<typename eT> | |
230 arma_hot | |
231 inline | |
232 static | |
233 void | |
234 apply | |
235 ( | |
236 Mat<eT>& C, | |
237 const Mat<eT>& A, | |
238 const Mat<eT>& B, | |
239 const eT alpha = eT(1), | |
240 const eT beta = eT(0), | |
241 const typename arma_not_cx<eT>::result* junk = 0 | |
242 ) | |
243 { | |
244 arma_extra_debug_sigprint(); | |
245 arma_ignore(junk); | |
246 | |
247 const uword A_n_rows = A.n_rows; | |
248 const uword A_n_cols = A.n_cols; | |
249 | |
250 const uword B_n_rows = B.n_rows; | |
251 const uword B_n_cols = B.n_cols; | |
252 | |
253 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) && (A_n_rows == B_n_rows) && (B_n_rows == B_n_cols) ) | |
254 { | |
255 if(do_trans_B == false) | |
256 { | |
257 gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, B, alpha, beta); | |
258 } | |
259 else | |
260 { | |
261 Mat<eT> BB(A_n_rows, A_n_rows); | |
262 op_strans::apply_noalias_tinysq(BB, B); | |
263 | |
264 gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, BB, alpha, beta); | |
265 } | |
266 } | |
267 else | |
268 { | |
269 gemm_emul_large<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C, A, B, alpha, beta); | |
270 } | |
271 } | |
272 | |
273 | |
274 | |
275 template<typename eT> | |
276 arma_hot | |
277 inline | |
278 static | |
279 void | |
280 apply | |
281 ( | |
282 Mat<eT>& C, | |
283 const Mat<eT>& A, | |
284 const Mat<eT>& B, | |
285 const eT alpha = eT(1), | |
286 const eT beta = eT(0), | |
287 const typename arma_cx_only<eT>::result* junk = 0 | |
288 ) | |
289 { | |
290 arma_extra_debug_sigprint(); | |
291 arma_ignore(junk); | |
292 | |
293 // "better than nothing" handling of hermitian transposes for complex number matrices | |
294 | |
295 Mat<eT> tmp_A; | |
296 Mat<eT> tmp_B; | |
297 | |
298 if(do_trans_A) | |
299 { | |
300 op_htrans::apply_noalias(tmp_A, A); | |
301 } | |
302 | |
303 if(do_trans_B) | |
304 { | |
305 op_htrans::apply_noalias(tmp_B, B); | |
306 } | |
307 | |
308 const Mat<eT>& AA = (do_trans_A == false) ? A : tmp_A; | |
309 const Mat<eT>& BB = (do_trans_B == false) ? B : tmp_B; | |
310 | |
311 const uword A_n_rows = AA.n_rows; | |
312 const uword A_n_cols = AA.n_cols; | |
313 | |
314 const uword B_n_rows = BB.n_rows; | |
315 const uword B_n_cols = BB.n_cols; | |
316 | |
317 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) && (A_n_rows == B_n_rows) && (B_n_rows == B_n_cols) ) | |
318 { | |
319 gemm_emul_tinysq<false, use_alpha, use_beta>::apply(C, AA, BB, alpha, beta); | |
320 } | |
321 else | |
322 { | |
323 gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, AA, BB, alpha, beta); | |
324 } | |
325 } | |
326 | |
327 }; | |
328 | |
329 | |
330 | |
331 //! \brief | |
332 //! Wrapper for ATLAS/BLAS dgemm function, using template arguments to control the arguments passed to dgemm. | |
333 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes) | |
334 | |
335 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false> | |
336 class gemm | |
337 { | |
338 public: | |
339 | |
340 template<typename eT> | |
341 inline | |
342 static | |
343 void | |
344 apply_blas_type( Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0) ) | |
345 { | |
346 arma_extra_debug_sigprint(); | |
347 | |
348 if( (A.n_elem <= 48u) && (B.n_elem <= 48u) ) | |
349 { | |
350 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta); | |
351 } | |
352 else | |
353 { | |
354 #if defined(ARMA_USE_ATLAS) | |
355 { | |
356 arma_extra_debug_print("atlas::cblas_gemm()"); | |
357 | |
358 atlas::cblas_gemm<eT> | |
359 ( | |
360 atlas::CblasColMajor, | |
361 (do_trans_A) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans, | |
362 (do_trans_B) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans, | |
363 C.n_rows, | |
364 C.n_cols, | |
365 (do_trans_A) ? A.n_rows : A.n_cols, | |
366 (use_alpha) ? alpha : eT(1), | |
367 A.mem, | |
368 (do_trans_A) ? A.n_rows : C.n_rows, | |
369 B.mem, | |
370 (do_trans_B) ? C.n_cols : ( (do_trans_A) ? A.n_rows : A.n_cols ), | |
371 (use_beta) ? beta : eT(0), | |
372 C.memptr(), | |
373 C.n_rows | |
374 ); | |
375 } | |
376 #elif defined(ARMA_USE_BLAS) | |
377 { | |
378 arma_extra_debug_print("blas::gemm()"); | |
379 | |
380 const char trans_A = (do_trans_A) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N'; | |
381 const char trans_B = (do_trans_B) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N'; | |
382 | |
383 const blas_int m = C.n_rows; | |
384 const blas_int n = C.n_cols; | |
385 const blas_int k = (do_trans_A) ? A.n_rows : A.n_cols; | |
386 | |
387 const eT local_alpha = (use_alpha) ? alpha : eT(1); | |
388 | |
389 const blas_int lda = (do_trans_A) ? k : m; | |
390 const blas_int ldb = (do_trans_B) ? n : k; | |
391 | |
392 const eT local_beta = (use_beta) ? beta : eT(0); | |
393 | |
394 arma_extra_debug_print( arma_boost::format("blas::gemm(): trans_A = %c") % trans_A ); | |
395 arma_extra_debug_print( arma_boost::format("blas::gemm(): trans_B = %c") % trans_B ); | |
396 | |
397 blas::gemm<eT> | |
398 ( | |
399 &trans_A, | |
400 &trans_B, | |
401 &m, | |
402 &n, | |
403 &k, | |
404 &local_alpha, | |
405 A.mem, | |
406 &lda, | |
407 B.mem, | |
408 &ldb, | |
409 &local_beta, | |
410 C.memptr(), | |
411 &m | |
412 ); | |
413 } | |
414 #else | |
415 { | |
416 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta); | |
417 } | |
418 #endif | |
419 } | |
420 } | |
421 | |
422 | |
423 | |
424 //! immediate multiplication of matrices A and B, storing the result in C | |
425 template<typename eT> | |
426 inline | |
427 static | |
428 void | |
429 apply( Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0) ) | |
430 { | |
431 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta); | |
432 } | |
433 | |
434 | |
435 | |
436 arma_inline | |
437 static | |
438 void | |
439 apply | |
440 ( | |
441 Mat<float>& C, | |
442 const Mat<float>& A, | |
443 const Mat<float>& B, | |
444 const float alpha = float(1), | |
445 const float beta = float(0) | |
446 ) | |
447 { | |
448 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta); | |
449 } | |
450 | |
451 | |
452 | |
453 arma_inline | |
454 static | |
455 void | |
456 apply | |
457 ( | |
458 Mat<double>& C, | |
459 const Mat<double>& A, | |
460 const Mat<double>& B, | |
461 const double alpha = double(1), | |
462 const double beta = double(0) | |
463 ) | |
464 { | |
465 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta); | |
466 } | |
467 | |
468 | |
469 | |
470 arma_inline | |
471 static | |
472 void | |
473 apply | |
474 ( | |
475 Mat< std::complex<float> >& C, | |
476 const Mat< std::complex<float> >& A, | |
477 const Mat< std::complex<float> >& B, | |
478 const std::complex<float> alpha = std::complex<float>(1), | |
479 const std::complex<float> beta = std::complex<float>(0) | |
480 ) | |
481 { | |
482 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta); | |
483 } | |
484 | |
485 | |
486 | |
487 arma_inline | |
488 static | |
489 void | |
490 apply | |
491 ( | |
492 Mat< std::complex<double> >& C, | |
493 const Mat< std::complex<double> >& A, | |
494 const Mat< std::complex<double> >& B, | |
495 const std::complex<double> alpha = std::complex<double>(1), | |
496 const std::complex<double> beta = std::complex<double>(0) | |
497 ) | |
498 { | |
499 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta); | |
500 } | |
501 | |
502 }; | |
503 | |
504 | |
505 | |
506 //! @} |