Chris@49
|
1 // Copyright (C) 2008-2013 NICTA (www.nicta.com.au)
|
Chris@49
|
2 // Copyright (C) 2008-2013 Conrad Sanderson
|
Chris@49
|
3 //
|
Chris@49
|
4 // This Source Code Form is subject to the terms of the Mozilla Public
|
Chris@49
|
5 // License, v. 2.0. If a copy of the MPL was not distributed with this
|
Chris@49
|
6 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
Chris@49
|
7
|
Chris@49
|
8
|
Chris@49
|
9 //! \addtogroup gemm
|
Chris@49
|
10 //! @{
|
Chris@49
|
11
|
Chris@49
|
12
|
Chris@49
|
13
|
Chris@49
|
14 //! for tiny square matrices, size <= 4x4
|
Chris@49
|
15 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
|
Chris@49
|
16 class gemm_emul_tinysq
|
Chris@49
|
17 {
|
Chris@49
|
18 public:
|
Chris@49
|
19
|
Chris@49
|
20
|
Chris@49
|
21 template<typename eT, typename TA, typename TB>
|
Chris@49
|
22 arma_hot
|
Chris@49
|
23 inline
|
Chris@49
|
24 static
|
Chris@49
|
25 void
|
Chris@49
|
26 apply
|
Chris@49
|
27 (
|
Chris@49
|
28 Mat<eT>& C,
|
Chris@49
|
29 const TA& A,
|
Chris@49
|
30 const TB& B,
|
Chris@49
|
31 const eT alpha = eT(1),
|
Chris@49
|
32 const eT beta = eT(0)
|
Chris@49
|
33 )
|
Chris@49
|
34 {
|
Chris@49
|
35 arma_extra_debug_sigprint();
|
Chris@49
|
36
|
Chris@49
|
37 switch(A.n_rows)
|
Chris@49
|
38 {
|
Chris@49
|
39 case 4:
|
Chris@49
|
40 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(3), A, B.colptr(3), alpha, beta );
|
Chris@49
|
41
|
Chris@49
|
42 case 3:
|
Chris@49
|
43 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(2), A, B.colptr(2), alpha, beta );
|
Chris@49
|
44
|
Chris@49
|
45 case 2:
|
Chris@49
|
46 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(1), A, B.colptr(1), alpha, beta );
|
Chris@49
|
47
|
Chris@49
|
48 case 1:
|
Chris@49
|
49 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply( C.colptr(0), A, B.colptr(0), alpha, beta );
|
Chris@49
|
50
|
Chris@49
|
51 default:
|
Chris@49
|
52 ;
|
Chris@49
|
53 }
|
Chris@49
|
54 }
|
Chris@49
|
55
|
Chris@49
|
56 };
|
Chris@49
|
57
|
Chris@49
|
58
|
Chris@49
|
59
|
Chris@49
|
60 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
|
Chris@49
|
61 class gemm_emul_large
|
Chris@49
|
62 {
|
Chris@49
|
63 public:
|
Chris@49
|
64
|
Chris@49
|
65 template<typename eT, typename TA, typename TB>
|
Chris@49
|
66 arma_hot
|
Chris@49
|
67 inline
|
Chris@49
|
68 static
|
Chris@49
|
69 void
|
Chris@49
|
70 apply
|
Chris@49
|
71 (
|
Chris@49
|
72 Mat<eT>& C,
|
Chris@49
|
73 const TA& A,
|
Chris@49
|
74 const TB& B,
|
Chris@49
|
75 const eT alpha = eT(1),
|
Chris@49
|
76 const eT beta = eT(0)
|
Chris@49
|
77 )
|
Chris@49
|
78 {
|
Chris@49
|
79 arma_extra_debug_sigprint();
|
Chris@49
|
80
|
Chris@49
|
81 const uword A_n_rows = A.n_rows;
|
Chris@49
|
82 const uword A_n_cols = A.n_cols;
|
Chris@49
|
83
|
Chris@49
|
84 const uword B_n_rows = B.n_rows;
|
Chris@49
|
85 const uword B_n_cols = B.n_cols;
|
Chris@49
|
86
|
Chris@49
|
87 if( (do_trans_A == false) && (do_trans_B == false) )
|
Chris@49
|
88 {
|
Chris@49
|
89 arma_aligned podarray<eT> tmp(A_n_cols);
|
Chris@49
|
90
|
Chris@49
|
91 eT* A_rowdata = tmp.memptr();
|
Chris@49
|
92
|
Chris@49
|
93 for(uword row_A=0; row_A < A_n_rows; ++row_A)
|
Chris@49
|
94 {
|
Chris@49
|
95 //tmp.copy_row(A, row_A);
|
Chris@49
|
96 const eT acc0 = op_dot::dot_and_copy_row(A_rowdata, A, row_A, B.colptr(0), A_n_cols);
|
Chris@49
|
97
|
Chris@49
|
98 if( (use_alpha == false) && (use_beta == false) )
|
Chris@49
|
99 {
|
Chris@49
|
100 C.at(row_A,0) = acc0;
|
Chris@49
|
101 }
|
Chris@49
|
102 else
|
Chris@49
|
103 if( (use_alpha == true) && (use_beta == false) )
|
Chris@49
|
104 {
|
Chris@49
|
105 C.at(row_A,0) = alpha * acc0;
|
Chris@49
|
106 }
|
Chris@49
|
107 else
|
Chris@49
|
108 if( (use_alpha == false) && (use_beta == true) )
|
Chris@49
|
109 {
|
Chris@49
|
110 C.at(row_A,0) = acc0 + beta*C.at(row_A,0);
|
Chris@49
|
111 }
|
Chris@49
|
112 else
|
Chris@49
|
113 if( (use_alpha == true) && (use_beta == true) )
|
Chris@49
|
114 {
|
Chris@49
|
115 C.at(row_A,0) = alpha*acc0 + beta*C.at(row_A,0);
|
Chris@49
|
116 }
|
Chris@49
|
117
|
Chris@49
|
118 //for(uword col_B=0; col_B < B_n_cols; ++col_B)
|
Chris@49
|
119 for(uword col_B=1; col_B < B_n_cols; ++col_B)
|
Chris@49
|
120 {
|
Chris@49
|
121 const eT acc = op_dot::direct_dot_arma(B_n_rows, A_rowdata, B.colptr(col_B));
|
Chris@49
|
122
|
Chris@49
|
123 if( (use_alpha == false) && (use_beta == false) )
|
Chris@49
|
124 {
|
Chris@49
|
125 C.at(row_A,col_B) = acc;
|
Chris@49
|
126 }
|
Chris@49
|
127 else
|
Chris@49
|
128 if( (use_alpha == true) && (use_beta == false) )
|
Chris@49
|
129 {
|
Chris@49
|
130 C.at(row_A,col_B) = alpha * acc;
|
Chris@49
|
131 }
|
Chris@49
|
132 else
|
Chris@49
|
133 if( (use_alpha == false) && (use_beta == true) )
|
Chris@49
|
134 {
|
Chris@49
|
135 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
|
Chris@49
|
136 }
|
Chris@49
|
137 else
|
Chris@49
|
138 if( (use_alpha == true) && (use_beta == true) )
|
Chris@49
|
139 {
|
Chris@49
|
140 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
|
Chris@49
|
141 }
|
Chris@49
|
142
|
Chris@49
|
143 }
|
Chris@49
|
144 }
|
Chris@49
|
145 }
|
Chris@49
|
146 else
|
Chris@49
|
147 if( (do_trans_A == true) && (do_trans_B == false) )
|
Chris@49
|
148 {
|
Chris@49
|
149 for(uword col_A=0; col_A < A_n_cols; ++col_A)
|
Chris@49
|
150 {
|
Chris@49
|
151 // col_A is interpreted as row_A when storing the results in matrix C
|
Chris@49
|
152
|
Chris@49
|
153 const eT* A_coldata = A.colptr(col_A);
|
Chris@49
|
154
|
Chris@49
|
155 for(uword col_B=0; col_B < B_n_cols; ++col_B)
|
Chris@49
|
156 {
|
Chris@49
|
157 const eT acc = op_dot::direct_dot_arma(B_n_rows, A_coldata, B.colptr(col_B));
|
Chris@49
|
158
|
Chris@49
|
159 if( (use_alpha == false) && (use_beta == false) )
|
Chris@49
|
160 {
|
Chris@49
|
161 C.at(col_A,col_B) = acc;
|
Chris@49
|
162 }
|
Chris@49
|
163 else
|
Chris@49
|
164 if( (use_alpha == true) && (use_beta == false) )
|
Chris@49
|
165 {
|
Chris@49
|
166 C.at(col_A,col_B) = alpha * acc;
|
Chris@49
|
167 }
|
Chris@49
|
168 else
|
Chris@49
|
169 if( (use_alpha == false) && (use_beta == true) )
|
Chris@49
|
170 {
|
Chris@49
|
171 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
|
Chris@49
|
172 }
|
Chris@49
|
173 else
|
Chris@49
|
174 if( (use_alpha == true) && (use_beta == true) )
|
Chris@49
|
175 {
|
Chris@49
|
176 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
|
Chris@49
|
177 }
|
Chris@49
|
178
|
Chris@49
|
179 }
|
Chris@49
|
180 }
|
Chris@49
|
181 }
|
Chris@49
|
182 else
|
Chris@49
|
183 if( (do_trans_A == false) && (do_trans_B == true) )
|
Chris@49
|
184 {
|
Chris@49
|
185 Mat<eT> BB;
|
Chris@49
|
186 op_strans::apply_noalias(BB, B);
|
Chris@49
|
187
|
Chris@49
|
188 gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, A, BB, alpha, beta);
|
Chris@49
|
189 }
|
Chris@49
|
190 else
|
Chris@49
|
191 if( (do_trans_A == true) && (do_trans_B == true) )
|
Chris@49
|
192 {
|
Chris@49
|
193 // mat B_tmp = trans(B);
|
Chris@49
|
194 // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
|
Chris@49
|
195
|
Chris@49
|
196
|
Chris@49
|
197 // By using the trans(A)*trans(B) = trans(B*A) equivalency,
|
Chris@49
|
198 // transpose operations are not needed
|
Chris@49
|
199
|
Chris@49
|
200 arma_aligned podarray<eT> tmp(B.n_cols);
|
Chris@49
|
201 eT* B_rowdata = tmp.memptr();
|
Chris@49
|
202
|
Chris@49
|
203 for(uword row_B=0; row_B < B_n_rows; ++row_B)
|
Chris@49
|
204 {
|
Chris@49
|
205 tmp.copy_row(B, row_B);
|
Chris@49
|
206
|
Chris@49
|
207 for(uword col_A=0; col_A < A_n_cols; ++col_A)
|
Chris@49
|
208 {
|
Chris@49
|
209 const eT acc = op_dot::direct_dot_arma(A_n_rows, B_rowdata, A.colptr(col_A));
|
Chris@49
|
210
|
Chris@49
|
211 if( (use_alpha == false) && (use_beta == false) )
|
Chris@49
|
212 {
|
Chris@49
|
213 C.at(col_A,row_B) = acc;
|
Chris@49
|
214 }
|
Chris@49
|
215 else
|
Chris@49
|
216 if( (use_alpha == true) && (use_beta == false) )
|
Chris@49
|
217 {
|
Chris@49
|
218 C.at(col_A,row_B) = alpha * acc;
|
Chris@49
|
219 }
|
Chris@49
|
220 else
|
Chris@49
|
221 if( (use_alpha == false) && (use_beta == true) )
|
Chris@49
|
222 {
|
Chris@49
|
223 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
|
Chris@49
|
224 }
|
Chris@49
|
225 else
|
Chris@49
|
226 if( (use_alpha == true) && (use_beta == true) )
|
Chris@49
|
227 {
|
Chris@49
|
228 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
|
Chris@49
|
229 }
|
Chris@49
|
230
|
Chris@49
|
231 }
|
Chris@49
|
232 }
|
Chris@49
|
233
|
Chris@49
|
234 }
|
Chris@49
|
235 }
|
Chris@49
|
236
|
Chris@49
|
237 };
|
Chris@49
|
238
|
Chris@49
|
239
|
Chris@49
|
240
|
Chris@49
|
241 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
|
Chris@49
|
242 class gemm_emul
|
Chris@49
|
243 {
|
Chris@49
|
244 public:
|
Chris@49
|
245
|
Chris@49
|
246
|
Chris@49
|
247 template<typename eT, typename TA, typename TB>
|
Chris@49
|
248 arma_hot
|
Chris@49
|
249 inline
|
Chris@49
|
250 static
|
Chris@49
|
251 void
|
Chris@49
|
252 apply
|
Chris@49
|
253 (
|
Chris@49
|
254 Mat<eT>& C,
|
Chris@49
|
255 const TA& A,
|
Chris@49
|
256 const TB& B,
|
Chris@49
|
257 const eT alpha = eT(1),
|
Chris@49
|
258 const eT beta = eT(0),
|
Chris@49
|
259 const typename arma_not_cx<eT>::result* junk = 0
|
Chris@49
|
260 )
|
Chris@49
|
261 {
|
Chris@49
|
262 arma_extra_debug_sigprint();
|
Chris@49
|
263 arma_ignore(junk);
|
Chris@49
|
264
|
Chris@49
|
265 const uword A_n_rows = A.n_rows;
|
Chris@49
|
266 const uword A_n_cols = A.n_cols;
|
Chris@49
|
267
|
Chris@49
|
268 const uword B_n_rows = B.n_rows;
|
Chris@49
|
269 const uword B_n_cols = B.n_cols;
|
Chris@49
|
270
|
Chris@49
|
271 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) && (A_n_rows == B_n_rows) && (B_n_rows == B_n_cols) )
|
Chris@49
|
272 {
|
Chris@49
|
273 if(do_trans_B == false)
|
Chris@49
|
274 {
|
Chris@49
|
275 gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, B, alpha, beta);
|
Chris@49
|
276 }
|
Chris@49
|
277 else
|
Chris@49
|
278 {
|
Chris@49
|
279 Mat<eT> BB(A_n_rows, A_n_rows);
|
Chris@49
|
280 op_strans::apply_noalias_tinysq(BB, B);
|
Chris@49
|
281
|
Chris@49
|
282 gemm_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(C, A, BB, alpha, beta);
|
Chris@49
|
283 }
|
Chris@49
|
284 }
|
Chris@49
|
285 else
|
Chris@49
|
286 {
|
Chris@49
|
287 gemm_emul_large<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C, A, B, alpha, beta);
|
Chris@49
|
288 }
|
Chris@49
|
289 }
|
Chris@49
|
290
|
Chris@49
|
291
|
Chris@49
|
292
|
Chris@49
|
293 template<typename eT>
|
Chris@49
|
294 arma_hot
|
Chris@49
|
295 inline
|
Chris@49
|
296 static
|
Chris@49
|
297 void
|
Chris@49
|
298 apply
|
Chris@49
|
299 (
|
Chris@49
|
300 Mat<eT>& C,
|
Chris@49
|
301 const Mat<eT>& A,
|
Chris@49
|
302 const Mat<eT>& B,
|
Chris@49
|
303 const eT alpha = eT(1),
|
Chris@49
|
304 const eT beta = eT(0),
|
Chris@49
|
305 const typename arma_cx_only<eT>::result* junk = 0
|
Chris@49
|
306 )
|
Chris@49
|
307 {
|
Chris@49
|
308 arma_extra_debug_sigprint();
|
Chris@49
|
309 arma_ignore(junk);
|
Chris@49
|
310
|
Chris@49
|
311 // "better than nothing" handling of hermitian transposes for complex number matrices
|
Chris@49
|
312
|
Chris@49
|
313 Mat<eT> tmp_A;
|
Chris@49
|
314 Mat<eT> tmp_B;
|
Chris@49
|
315
|
Chris@49
|
316 if(do_trans_A)
|
Chris@49
|
317 {
|
Chris@49
|
318 op_htrans::apply_noalias(tmp_A, A);
|
Chris@49
|
319 }
|
Chris@49
|
320
|
Chris@49
|
321 if(do_trans_B)
|
Chris@49
|
322 {
|
Chris@49
|
323 op_htrans::apply_noalias(tmp_B, B);
|
Chris@49
|
324 }
|
Chris@49
|
325
|
Chris@49
|
326 const Mat<eT>& AA = (do_trans_A == false) ? A : tmp_A;
|
Chris@49
|
327 const Mat<eT>& BB = (do_trans_B == false) ? B : tmp_B;
|
Chris@49
|
328
|
Chris@49
|
329 const uword A_n_rows = AA.n_rows;
|
Chris@49
|
330 const uword A_n_cols = AA.n_cols;
|
Chris@49
|
331
|
Chris@49
|
332 const uword B_n_rows = BB.n_rows;
|
Chris@49
|
333 const uword B_n_cols = BB.n_cols;
|
Chris@49
|
334
|
Chris@49
|
335 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) && (A_n_rows == B_n_rows) && (B_n_rows == B_n_cols) )
|
Chris@49
|
336 {
|
Chris@49
|
337 gemm_emul_tinysq<false, use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
|
Chris@49
|
338 }
|
Chris@49
|
339 else
|
Chris@49
|
340 {
|
Chris@49
|
341 gemm_emul_large<false, false, use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
|
Chris@49
|
342 }
|
Chris@49
|
343 }
|
Chris@49
|
344
|
Chris@49
|
345 };
|
Chris@49
|
346
|
Chris@49
|
347
|
Chris@49
|
348
|
Chris@49
|
349 //! \brief
|
Chris@49
|
350 //! Wrapper for ATLAS/BLAS dgemm function, using template arguments to control the arguments passed to dgemm.
|
Chris@49
|
351 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
|
Chris@49
|
352
|
Chris@49
|
353 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
|
Chris@49
|
354 class gemm
|
Chris@49
|
355 {
|
Chris@49
|
356 public:
|
Chris@49
|
357
|
Chris@49
|
358 template<typename eT, typename TA, typename TB>
|
Chris@49
|
359 inline
|
Chris@49
|
360 static
|
Chris@49
|
361 void
|
Chris@49
|
362 apply_blas_type( Mat<eT>& C, const TA& A, const TB& B, const eT alpha = eT(1), const eT beta = eT(0) )
|
Chris@49
|
363 {
|
Chris@49
|
364 arma_extra_debug_sigprint();
|
Chris@49
|
365
|
Chris@49
|
366 const uword threshold = (is_Mat_fixed<TA>::value && is_Mat_fixed<TB>::value)
|
Chris@49
|
367 ? (is_complex<eT>::value ? 16u : 64u)
|
Chris@49
|
368 : (is_complex<eT>::value ? 16u : 48u);
|
Chris@49
|
369
|
Chris@49
|
370 if( (A.n_elem <= threshold) && (B.n_elem <= threshold) )
|
Chris@49
|
371 {
|
Chris@49
|
372 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
|
Chris@49
|
373 }
|
Chris@49
|
374 else
|
Chris@49
|
375 {
|
Chris@49
|
376 #if defined(ARMA_USE_ATLAS)
|
Chris@49
|
377 {
|
Chris@49
|
378 arma_extra_debug_print("atlas::cblas_gemm()");
|
Chris@49
|
379
|
Chris@49
|
380 atlas::cblas_gemm<eT>
|
Chris@49
|
381 (
|
Chris@49
|
382 atlas::CblasColMajor,
|
Chris@49
|
383 (do_trans_A) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
|
Chris@49
|
384 (do_trans_B) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
|
Chris@49
|
385 C.n_rows,
|
Chris@49
|
386 C.n_cols,
|
Chris@49
|
387 (do_trans_A) ? A.n_rows : A.n_cols,
|
Chris@49
|
388 (use_alpha) ? alpha : eT(1),
|
Chris@49
|
389 A.mem,
|
Chris@49
|
390 (do_trans_A) ? A.n_rows : C.n_rows,
|
Chris@49
|
391 B.mem,
|
Chris@49
|
392 (do_trans_B) ? C.n_cols : ( (do_trans_A) ? A.n_rows : A.n_cols ),
|
Chris@49
|
393 (use_beta) ? beta : eT(0),
|
Chris@49
|
394 C.memptr(),
|
Chris@49
|
395 C.n_rows
|
Chris@49
|
396 );
|
Chris@49
|
397 }
|
Chris@49
|
398 #elif defined(ARMA_USE_BLAS)
|
Chris@49
|
399 {
|
Chris@49
|
400 arma_extra_debug_print("blas::gemm()");
|
Chris@49
|
401
|
Chris@49
|
402 const char trans_A = (do_trans_A) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N';
|
Chris@49
|
403 const char trans_B = (do_trans_B) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N';
|
Chris@49
|
404
|
Chris@49
|
405 const blas_int m = C.n_rows;
|
Chris@49
|
406 const blas_int n = C.n_cols;
|
Chris@49
|
407 const blas_int k = (do_trans_A) ? A.n_rows : A.n_cols;
|
Chris@49
|
408
|
Chris@49
|
409 const eT local_alpha = (use_alpha) ? alpha : eT(1);
|
Chris@49
|
410
|
Chris@49
|
411 const blas_int lda = (do_trans_A) ? k : m;
|
Chris@49
|
412 const blas_int ldb = (do_trans_B) ? n : k;
|
Chris@49
|
413
|
Chris@49
|
414 const eT local_beta = (use_beta) ? beta : eT(0);
|
Chris@49
|
415
|
Chris@49
|
416 arma_extra_debug_print( arma_boost::format("blas::gemm(): trans_A = %c") % trans_A );
|
Chris@49
|
417 arma_extra_debug_print( arma_boost::format("blas::gemm(): trans_B = %c") % trans_B );
|
Chris@49
|
418
|
Chris@49
|
419 blas::gemm<eT>
|
Chris@49
|
420 (
|
Chris@49
|
421 &trans_A,
|
Chris@49
|
422 &trans_B,
|
Chris@49
|
423 &m,
|
Chris@49
|
424 &n,
|
Chris@49
|
425 &k,
|
Chris@49
|
426 &local_alpha,
|
Chris@49
|
427 A.mem,
|
Chris@49
|
428 &lda,
|
Chris@49
|
429 B.mem,
|
Chris@49
|
430 &ldb,
|
Chris@49
|
431 &local_beta,
|
Chris@49
|
432 C.memptr(),
|
Chris@49
|
433 &m
|
Chris@49
|
434 );
|
Chris@49
|
435 }
|
Chris@49
|
436 #else
|
Chris@49
|
437 {
|
Chris@49
|
438 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
|
Chris@49
|
439 }
|
Chris@49
|
440 #endif
|
Chris@49
|
441 }
|
Chris@49
|
442 }
|
Chris@49
|
443
|
Chris@49
|
444
|
Chris@49
|
445
|
Chris@49
|
446 //! immediate multiplication of matrices A and B, storing the result in C
|
Chris@49
|
447 template<typename eT, typename TA, typename TB>
|
Chris@49
|
448 inline
|
Chris@49
|
449 static
|
Chris@49
|
450 void
|
Chris@49
|
451 apply( Mat<eT>& C, const TA& A, const TB& B, const eT alpha = eT(1), const eT beta = eT(0) )
|
Chris@49
|
452 {
|
Chris@49
|
453 gemm_emul<do_trans_A, do_trans_B, use_alpha, use_beta>::apply(C,A,B,alpha,beta);
|
Chris@49
|
454 }
|
Chris@49
|
455
|
Chris@49
|
456
|
Chris@49
|
457
|
Chris@49
|
458 template<typename TA, typename TB>
|
Chris@49
|
459 arma_inline
|
Chris@49
|
460 static
|
Chris@49
|
461 void
|
Chris@49
|
462 apply
|
Chris@49
|
463 (
|
Chris@49
|
464 Mat<float>& C,
|
Chris@49
|
465 const TA& A,
|
Chris@49
|
466 const TB& B,
|
Chris@49
|
467 const float alpha = float(1),
|
Chris@49
|
468 const float beta = float(0)
|
Chris@49
|
469 )
|
Chris@49
|
470 {
|
Chris@49
|
471 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
|
Chris@49
|
472 }
|
Chris@49
|
473
|
Chris@49
|
474
|
Chris@49
|
475
|
Chris@49
|
476 template<typename TA, typename TB>
|
Chris@49
|
477 arma_inline
|
Chris@49
|
478 static
|
Chris@49
|
479 void
|
Chris@49
|
480 apply
|
Chris@49
|
481 (
|
Chris@49
|
482 Mat<double>& C,
|
Chris@49
|
483 const TA& A,
|
Chris@49
|
484 const TB& B,
|
Chris@49
|
485 const double alpha = double(1),
|
Chris@49
|
486 const double beta = double(0)
|
Chris@49
|
487 )
|
Chris@49
|
488 {
|
Chris@49
|
489 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
|
Chris@49
|
490 }
|
Chris@49
|
491
|
Chris@49
|
492
|
Chris@49
|
493
|
Chris@49
|
494 template<typename TA, typename TB>
|
Chris@49
|
495 arma_inline
|
Chris@49
|
496 static
|
Chris@49
|
497 void
|
Chris@49
|
498 apply
|
Chris@49
|
499 (
|
Chris@49
|
500 Mat< std::complex<float> >& C,
|
Chris@49
|
501 const TA& A,
|
Chris@49
|
502 const TB& B,
|
Chris@49
|
503 const std::complex<float> alpha = std::complex<float>(1),
|
Chris@49
|
504 const std::complex<float> beta = std::complex<float>(0)
|
Chris@49
|
505 )
|
Chris@49
|
506 {
|
Chris@49
|
507 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
|
Chris@49
|
508 }
|
Chris@49
|
509
|
Chris@49
|
510
|
Chris@49
|
511
|
Chris@49
|
512 template<typename TA, typename TB>
|
Chris@49
|
513 arma_inline
|
Chris@49
|
514 static
|
Chris@49
|
515 void
|
Chris@49
|
516 apply
|
Chris@49
|
517 (
|
Chris@49
|
518 Mat< std::complex<double> >& C,
|
Chris@49
|
519 const TA& A,
|
Chris@49
|
520 const TB& B,
|
Chris@49
|
521 const std::complex<double> alpha = std::complex<double>(1),
|
Chris@49
|
522 const std::complex<double> beta = std::complex<double>(0)
|
Chris@49
|
523 )
|
Chris@49
|
524 {
|
Chris@49
|
525 gemm<do_trans_A, do_trans_B, use_alpha, use_beta>::apply_blas_type(C,A,B,alpha,beta);
|
Chris@49
|
526 }
|
Chris@49
|
527
|
Chris@49
|
528 };
|
Chris@49
|
529
|
Chris@49
|
530
|
Chris@49
|
531
|
Chris@49
|
532 //! @}
|