comparison armadillo-2.4.4/include/armadillo_bits/gemm.hpp @ 0:8b6102e2a9b0

Armadillo Library
author maxzanoni76 <max.zanoni@eecs.qmul.ac.uk>
date Wed, 11 Apr 2012 09:27:06 +0100
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:8b6102e2a9b0
1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
2 // Copyright (C) 2008-2011 Conrad Sanderson
3 //
4 // This file is part of the Armadillo C++ library.
5 // It is provided without any warranty of fitness
6 // for any purpose. You can redistribute this file
7 // and/or modify it under the terms of the GNU
8 // Lesser General Public License (LGPL) as published
9 // by the Free Software Foundation, either version 3
10 // of the License or (at your option) any later version.
11 // (see http://www.opensource.org/licenses for more info)
12
13
14 //! \addtogroup gemm
15 //! @{
16
17
18
19 //! for tiny square matrices, size <= 4x4
20 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
21 class gemm_emul_tinysq
22 {
23 public:
24
25
26 template<typename eT>
27 arma_hot
28 inline
29 static
30 void
31 apply
32 (
33 Mat<eT>& C,
34 const Mat<eT>& A,
35 const Mat<eT>& B,
36 const eT alpha = eT(1),
37 const eT beta = eT(0)
38 )
39 {
40 arma_extra_debug_sigprint();
41
42 switch(A.n_rows)
43 {
44 case 4:
45 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(3), A, B.colptr(3), alpha, beta );
46
47 case 3:
48 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(2), A, B.colptr(2), alpha, beta );
49
50 case 2:
51 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(1), A, B.colptr(1), alpha, beta );
52
53 case 1:
54 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(0), A, B.colptr(0), alpha, beta );
55
56 default:
57 ;
58 }
59 }
60
61 };
62
63
64
65 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
66 class gemm_emul_large
67 {
68 public:
69
70 template<typename eT>
71 arma_hot
72 inline
73 static
74 void
75 apply
76 (
77 Mat<eT>& C,
78 const Mat<eT>& A,
79 const Mat<eT>& B,
80 const eT alpha = eT(1),
81 const eT beta = eT(0)
82 )
83 {
84 arma_extra_debug_sigprint();
85
86 const uword A_n_rows = A.n_rows;
87 const uword A_n_cols = A.n_cols;
88
89 const uword B_n_rows = B.n_rows;
90 const uword B_n_cols = B.n_cols;
91
92 if( (do_trans_A == false) && (do_trans_B == false) )
93 {
94 arma_aligned podarray<eT> tmp(A_n_cols);
95 eT* A_rowdata = tmp.memptr();
96
97 for(uword row_A=0; row_A < A_n_rows; ++row_A)
98 {
99 tmp.copy_row(A, row_A);
100
101 for(uword col_B=0; col_B < B_n_cols; ++col_B)
102 {
103 const eT acc = op_dot::direct_dot_arma(B_n_rows, A_rowdata, B.colptr(col_B));
104
105 if( (use_alpha == false) && (use_beta == false) )
106 {
107 C.at(row_A,col_B) = acc;
108 }
109 else
110 if( (use_alpha == true) && (use_beta == false) )
111 {
112 C.at(row_A,col_B) = alpha * acc;
113 }
114 else
115 if( (use_alpha == false) && (use_beta == true) )
116 {
117 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
118 }
119 else
120 if( (use_alpha == true) && (use_beta == true) )
121 {
122 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
123 }
124
125 }
126 }
127 }
128 else
129 if( (do_trans_A == true) && (do_trans_B == false) )
130 {
131 for(uword col_A=0; col_A < A_n_cols; ++col_A)
132 {
133 // col_A is interpreted as row_A when storing the results in matrix C
134
135 const eT* A_coldata = A.colptr(col_A);
136
137 for(uword col_B=0; col_B < B_n_cols; ++col_B)
138 {
139 const eT acc = op_dot::direct_dot_arma(B_n_rows, A_coldata, B.colptr(col_B));
140
141 if( (use_alpha == false) && (use_beta == false) )
142 {
143 C.at(col_A,col_B) = acc;
144 }
145 else
146 if( (use_alpha == true) && (use_beta == false) )
147 {
148 C.at(col_A,col_B) = alpha * acc;
149 }
150 else
151 if( (use_alpha == false) && (use_beta == true) )
152 {
153 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
154 }
155 else
156 if( (use_alpha == true) && (use_beta == true) )
157 {
158 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
159 }
160
161 }
162 }
163 }
164 else
165 if( (do_trans_A == false) && (do_trans_B == true) )
166 {
167 Mat<eT> BB;
168 op_strans::apply_noalias(BB, B);
169
170 gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, A, BB, alpha, beta);
171 }
172 else
173 if( (do_trans_A == true) && (do_trans_B == true) )
174 {
175 // mat B_tmp = trans(B);
176 // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
177
178
179 // By using the trans(A)*trans(B) = trans(B*A) equivalency,
180 // transpose operations are not needed
181
182 arma_aligned podarray<eT> tmp(B.n_cols);
183 eT* B_rowdata = tmp.memptr();
184
185 for(uword row_B=0; row_B < B_n_rows; ++row_B)
186 {
187 tmp.copy_row(B, row_B);
188
189 for(uword col_A=0; col_A < A_n_cols; ++col_A)
190 {
191 const eT acc = op_dot::direct_dot_arma(A_n_rows, B_rowdata, A.colptr(col_A));
192
193 if( (use_alpha == false) && (use_beta == false) )
194 {
195 C.at(col_A,row_B) = acc;
196 }
197 else
198 if( (use_alpha == true) && (use_beta == false) )
199 {
200 C.at(col_A,row_B) = alpha * acc;
201 }
202 else
203 if( (use_alpha == false) && (use_beta == true) )
204 {
205 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
206 }
207 else
208 if( (use_alpha == true) && (use_beta == true) )
209 {
210 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
211 }
212
213 }
214 }
215
216 }
217 }
218
219 };
220
221
222
223 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
224 class gemm_emul
225 {
226 public:
227
228
229 template<typename eT>
230 arma_hot
231 inline
232 static
233 void
234 apply
235 (
236 Mat<eT>& C,
237 const Mat<eT>& A,
238 const Mat<eT>& B,
239 const eT alpha = eT(1),
240 const eT beta = eT(0),
241 const typename arma_not_cx<eT>::result* junk = 0
242 )
243 {
244 arma_extra_debug_sigprint();
245 arma_ignore(junk);
246
247 const uword A_n_rows = A.n_rows;
248 const uword A_n_cols = A.n_cols;
249
250 const uword B_n_rows = B.n_rows;
251 const uword B_n_cols = B.n_cols;
252
253 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) && (A_n_rows == B_n_rows) && (B_n_rows == B_n_cols) )
254 {
255 if(do_trans_B == false)
256 {
257 gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, B, alpha, beta);
258 }
259 else
260 {
261 Mat<eT> BB(A_n_rows, A_n_rows);
262 op_strans::apply_noalias_tinysq(BB, B);
263
264 gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, BB, alpha, beta);
265 }
266 }
267 else
268 {
269 gemm_emul_large<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C, A, B, alpha, beta);
270 }
271 }
272
273
274
275 template<typename eT>
276 arma_hot
277 inline
278 static
279 void
280 apply
281 (
282 Mat<eT>& C,
283 const Mat<eT>& A,
284 const Mat<eT>& B,
285 const eT alpha = eT(1),
286 const eT beta = eT(0),
287 const typename arma_cx_only<eT>::result* junk = 0
288 )
289 {
290 arma_extra_debug_sigprint();
291 arma_ignore(junk);
292
293 // "better than nothing" handling of hermitian transposes for complex number matrices
294
295 Mat<eT> tmp_A;
296 Mat<eT> tmp_B;
297
298 if(do_trans_A)
299 {
300 op_htrans::apply_noalias(tmp_A, A);
301 }
302
303 if(do_trans_B)
304 {
305 op_htrans::apply_noalias(tmp_B, B);
306 }
307
308 const Mat<eT>& AA = (do_trans_A == false) ? A : tmp_A;
309 const Mat<eT>& BB = (do_trans_B == false) ? B : tmp_B;
310
311 const uword A_n_rows = AA.n_rows;
312 const uword A_n_cols = AA.n_cols;
313
314 const uword B_n_rows = BB.n_rows;
315 const uword B_n_cols = BB.n_cols;
316
317 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) && (A_n_rows == B_n_rows) && (B_n_rows == B_n_cols) )
318 {
319 gemm_emul_tinysq<false, use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
320 }
321 else
322 {
323 gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
324 }
325 }
326
327 };
328
329
330
331 //! \brief
332 //! Wrapper for ATLAS/BLAS dgemm function, using template arguments to control the arguments passed to dgemm.
333 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
334
335 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
336 class gemm
337 {
338 public:
339
340 template<typename eT>
341 inline
342 static
343 void
344 apply_blas_type( Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0) )
345 {
346 arma_extra_debug_sigprint();
347
348 if( (A.n_elem <= 48u) && (B.n_elem <= 48u) )
349 {
350 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
351 }
352 else
353 {
354 #if defined(ARMA_USE_ATLAS)
355 {
356 arma_extra_debug_print("atlas::cblas_gemm()");
357
358 atlas::cblas_gemm<eT>
359 (
360 atlas::CblasColMajor,
361 (do_trans_A) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
362 (do_trans_B) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
363 C.n_rows,
364 C.n_cols,
365 (do_trans_A) ? A.n_rows : A.n_cols,
366 (use_alpha) ? alpha : eT(1),
367 A.mem,
368 (do_trans_A) ? A.n_rows : C.n_rows,
369 B.mem,
370 (do_trans_B) ? C.n_cols : ( (do_trans_A) ? A.n_rows : A.n_cols ),
371 (use_beta) ? beta : eT(0),
372 C.memptr(),
373 C.n_rows
374 );
375 }
376 #elif defined(ARMA_USE_BLAS)
377 {
378 arma_extra_debug_print("blas::gemm()");
379
380 const char trans_A = (do_trans_A) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N';
381 const char trans_B = (do_trans_B) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N';
382
383 const blas_int m = C.n_rows;
384 const blas_int n = C.n_cols;
385 const blas_int k = (do_trans_A) ? A.n_rows : A.n_cols;
386
387 const eT local_alpha = (use_alpha) ? alpha : eT(1);
388
389 const blas_int lda = (do_trans_A) ? k : m;
390 const blas_int ldb = (do_trans_B) ? n : k;
391
392 const eT local_beta = (use_beta) ? beta : eT(0);
393
394 arma_extra_debug_print( arma_boost::format("blas::gemm(): trans_A = %c") % trans_A );
395 arma_extra_debug_print( arma_boost::format("blas::gemm(): trans_B = %c") % trans_B );
396
397 blas::gemm<eT>
398 (
399 &trans_A,
400 &trans_B,
401 &m,
402 &n,
403 &k,
404 &local_alpha,
405 A.mem,
406 &lda,
407 B.mem,
408 &ldb,
409 &local_beta,
410 C.memptr(),
411 &m
412 );
413 }
414 #else
415 {
416 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
417 }
418 #endif
419 }
420 }
421
422
423
424 //! immediate multiplication of matrices A and B, storing the result in C
425 template<typename eT>
426 inline
427 static
428 void
429 apply( Mat<eT>& C, const Mat<eT>& A, const Mat<eT>& B, const eT alpha = eT(1), const eT beta = eT(0) )
430 {
431 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
432 }
433
434
435
436 arma_inline
437 static
438 void
439 apply
440 (
441 Mat<float>& C,
442 const Mat<float>& A,
443 const Mat<float>& B,
444 const float alpha = float(1),
445 const float beta = float(0)
446 )
447 {
448 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
449 }
450
451
452
453 arma_inline
454 static
455 void
456 apply
457 (
458 Mat<double>& C,
459 const Mat<double>& A,
460 const Mat<double>& B,
461 const double alpha = double(1),
462 const double beta = double(0)
463 )
464 {
465 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
466 }
467
468
469
470 arma_inline
471 static
472 void
473 apply
474 (
475 Mat< std::complex<float> >& C,
476 const Mat< std::complex<float> >& A,
477 const Mat< std::complex<float> >& B,
478 const std::complex<float> alpha = std::complex<float>(1),
479 const std::complex<float> beta = std::complex<float>(0)
480 )
481 {
482 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
483 }
484
485
486
487 arma_inline
488 static
489 void
490 apply
491 (
492 Mat< std::complex<double> >& C,
493 const Mat< std::complex<double> >& A,
494 const Mat< std::complex<double> >& B,
495 const std::complex<double> alpha = std::complex<double>(1),
496 const std::complex<double> beta = std::complex<double>(0)
497 )
498 {
499 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
500 }
501
502 };
503
504
505
506 //! @}