comparison armadillo-3.900.4/include/armadillo_bits/gemm_mixed.hpp @ 49:1ec0e2823891

Switch to using subrepo copies of qm-dsp, nnls-chroma, vamp-plugin-sdk; update Armadillo version; assume build without external BLAS/LAPACK
author Chris Cannam
date Thu, 13 Jun 2013 10:25:24 +0100
parents
children
comparison
equal deleted inserted replaced
48:69251e11a913 49:1ec0e2823891
1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
2 // Copyright (C) 2008-2011 Conrad Sanderson
3 //
4 // This Source Code Form is subject to the terms of the Mozilla Public
5 // License, v. 2.0. If a copy of the MPL was not distributed with this
6 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
7
8
9 //! \addtogroup gemm_mixed
10 //! @{
11
12
13
14 //! \brief
15 //! Matrix multplication where the matrices have differing element types.
16 //! Uses caching for speedup.
17 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
18
19 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
20 class gemm_mixed_large
21 {
22 public:
23
24 template<typename out_eT, typename in_eT1, typename in_eT2>
25 arma_hot
26 inline
27 static
28 void
29 apply
30 (
31 Mat<out_eT>& C,
32 const Mat<in_eT1>& A,
33 const Mat<in_eT2>& B,
34 const out_eT alpha = out_eT(1),
35 const out_eT beta = out_eT(0)
36 )
37 {
38 arma_extra_debug_sigprint();
39
40 const uword A_n_rows = A.n_rows;
41 const uword A_n_cols = A.n_cols;
42
43 const uword B_n_rows = B.n_rows;
44 const uword B_n_cols = B.n_cols;
45
46 if( (do_trans_A == false) && (do_trans_B == false) )
47 {
48 podarray<in_eT1> tmp(A_n_cols);
49 in_eT1* A_rowdata = tmp.memptr();
50
51 for(uword row_A=0; row_A < A_n_rows; ++row_A)
52 {
53 tmp.copy_row(A, row_A);
54
55 for(uword col_B=0; col_B < B_n_cols; ++col_B)
56 {
57 const in_eT2* B_coldata = B.colptr(col_B);
58
59 out_eT acc = out_eT(0);
60 for(uword i=0; i < B_n_rows; ++i)
61 {
62 acc += upgrade_val<in_eT1,in_eT2>::apply(A_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
63 }
64
65 if( (use_alpha == false) && (use_beta == false) )
66 {
67 C.at(row_A,col_B) = acc;
68 }
69 else
70 if( (use_alpha == true) && (use_beta == false) )
71 {
72 C.at(row_A,col_B) = alpha * acc;
73 }
74 else
75 if( (use_alpha == false) && (use_beta == true) )
76 {
77 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
78 }
79 else
80 if( (use_alpha == true) && (use_beta == true) )
81 {
82 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
83 }
84
85 }
86 }
87 }
88 else
89 if( (do_trans_A == true) && (do_trans_B == false) )
90 {
91 for(uword col_A=0; col_A < A_n_cols; ++col_A)
92 {
93 // col_A is interpreted as row_A when storing the results in matrix C
94
95 const in_eT1* A_coldata = A.colptr(col_A);
96
97 for(uword col_B=0; col_B < B_n_cols; ++col_B)
98 {
99 const in_eT2* B_coldata = B.colptr(col_B);
100
101 out_eT acc = out_eT(0);
102 for(uword i=0; i < B_n_rows; ++i)
103 {
104 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
105 }
106
107 if( (use_alpha == false) && (use_beta == false) )
108 {
109 C.at(col_A,col_B) = acc;
110 }
111 else
112 if( (use_alpha == true) && (use_beta == false) )
113 {
114 C.at(col_A,col_B) = alpha * acc;
115 }
116 else
117 if( (use_alpha == false) && (use_beta == true) )
118 {
119 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
120 }
121 else
122 if( (use_alpha == true) && (use_beta == true) )
123 {
124 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
125 }
126
127 }
128 }
129 }
130 else
131 if( (do_trans_A == false) && (do_trans_B == true) )
132 {
133 Mat<in_eT2> B_tmp;
134
135 op_strans::apply_noalias(B_tmp, B);
136
137 gemm_mixed_large<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
138 }
139 else
140 if( (do_trans_A == true) && (do_trans_B == true) )
141 {
142 // mat B_tmp = trans(B);
143 // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
144
145
146 // By using the trans(A)*trans(B) = trans(B*A) equivalency,
147 // transpose operations are not needed
148
149 podarray<in_eT2> tmp(B_n_cols);
150 in_eT2* B_rowdata = tmp.memptr();
151
152 for(uword row_B=0; row_B < B_n_rows; ++row_B)
153 {
154 tmp.copy_row(B, row_B);
155
156 for(uword col_A=0; col_A < A_n_cols; ++col_A)
157 {
158 const in_eT1* A_coldata = A.colptr(col_A);
159
160 out_eT acc = out_eT(0);
161 for(uword i=0; i < A_n_rows; ++i)
162 {
163 acc += upgrade_val<in_eT1,in_eT2>::apply(B_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
164 }
165
166 if( (use_alpha == false) && (use_beta == false) )
167 {
168 C.at(col_A,row_B) = acc;
169 }
170 else
171 if( (use_alpha == true) && (use_beta == false) )
172 {
173 C.at(col_A,row_B) = alpha * acc;
174 }
175 else
176 if( (use_alpha == false) && (use_beta == true) )
177 {
178 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
179 }
180 else
181 if( (use_alpha == true) && (use_beta == true) )
182 {
183 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
184 }
185
186 }
187 }
188
189 }
190 }
191
192 };
193
194
195
196 //! Matrix multplication where the matrices have different element types.
197 //! Simple version (no caching).
198 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
199 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
200 class gemm_mixed_small
201 {
202 public:
203
204 template<typename out_eT, typename in_eT1, typename in_eT2>
205 arma_hot
206 inline
207 static
208 void
209 apply
210 (
211 Mat<out_eT>& C,
212 const Mat<in_eT1>& A,
213 const Mat<in_eT2>& B,
214 const out_eT alpha = out_eT(1),
215 const out_eT beta = out_eT(0)
216 )
217 {
218 arma_extra_debug_sigprint();
219
220 const uword A_n_rows = A.n_rows;
221 const uword A_n_cols = A.n_cols;
222
223 const uword B_n_rows = B.n_rows;
224 const uword B_n_cols = B.n_cols;
225
226 if( (do_trans_A == false) && (do_trans_B == false) )
227 {
228 for(uword row_A = 0; row_A < A_n_rows; ++row_A)
229 {
230 for(uword col_B = 0; col_B < B_n_cols; ++col_B)
231 {
232 const in_eT2* B_coldata = B.colptr(col_B);
233
234 out_eT acc = out_eT(0);
235 for(uword i = 0; i < B_n_rows; ++i)
236 {
237 const out_eT val1 = upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i));
238 const out_eT val2 = upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
239 acc += val1 * val2;
240 //acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
241 }
242
243 if( (use_alpha == false) && (use_beta == false) )
244 {
245 C.at(row_A,col_B) = acc;
246 }
247 else
248 if( (use_alpha == true) && (use_beta == false) )
249 {
250 C.at(row_A,col_B) = alpha * acc;
251 }
252 else
253 if( (use_alpha == false) && (use_beta == true) )
254 {
255 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
256 }
257 else
258 if( (use_alpha == true) && (use_beta == true) )
259 {
260 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
261 }
262 }
263 }
264 }
265 else
266 if( (do_trans_A == true) && (do_trans_B == false) )
267 {
268 for(uword col_A=0; col_A < A_n_cols; ++col_A)
269 {
270 // col_A is interpreted as row_A when storing the results in matrix C
271
272 const in_eT1* A_coldata = A.colptr(col_A);
273
274 for(uword col_B=0; col_B < B_n_cols; ++col_B)
275 {
276 const in_eT2* B_coldata = B.colptr(col_B);
277
278 out_eT acc = out_eT(0);
279 for(uword i=0; i < B_n_rows; ++i)
280 {
281 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
282 }
283
284 if( (use_alpha == false) && (use_beta == false) )
285 {
286 C.at(col_A,col_B) = acc;
287 }
288 else
289 if( (use_alpha == true) && (use_beta == false) )
290 {
291 C.at(col_A,col_B) = alpha * acc;
292 }
293 else
294 if( (use_alpha == false) && (use_beta == true) )
295 {
296 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
297 }
298 else
299 if( (use_alpha == true) && (use_beta == true) )
300 {
301 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
302 }
303
304 }
305 }
306 }
307 else
308 if( (do_trans_A == false) && (do_trans_B == true) )
309 {
310 for(uword row_A = 0; row_A < A_n_rows; ++row_A)
311 {
312 for(uword row_B = 0; row_B < B_n_rows; ++row_B)
313 {
314 out_eT acc = out_eT(0);
315 for(uword i = 0; i < B_n_cols; ++i)
316 {
317 acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i));
318 }
319
320 if( (use_alpha == false) && (use_beta == false) )
321 {
322 C.at(row_A,row_B) = acc;
323 }
324 else
325 if( (use_alpha == true) && (use_beta == false) )
326 {
327 C.at(row_A,row_B) = alpha * acc;
328 }
329 else
330 if( (use_alpha == false) && (use_beta == true) )
331 {
332 C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B);
333 }
334 else
335 if( (use_alpha == true) && (use_beta == true) )
336 {
337 C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B);
338 }
339 }
340 }
341 }
342 else
343 if( (do_trans_A == true) && (do_trans_B == true) )
344 {
345 for(uword row_B=0; row_B < B_n_rows; ++row_B)
346 {
347
348 for(uword col_A=0; col_A < A_n_cols; ++col_A)
349 {
350 const in_eT1* A_coldata = A.colptr(col_A);
351
352 out_eT acc = out_eT(0);
353 for(uword i=0; i < A_n_rows; ++i)
354 {
355 acc += upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
356 }
357
358 if( (use_alpha == false) && (use_beta == false) )
359 {
360 C.at(col_A,row_B) = acc;
361 }
362 else
363 if( (use_alpha == true) && (use_beta == false) )
364 {
365 C.at(col_A,row_B) = alpha * acc;
366 }
367 else
368 if( (use_alpha == false) && (use_beta == true) )
369 {
370 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
371 }
372 else
373 if( (use_alpha == true) && (use_beta == true) )
374 {
375 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
376 }
377
378 }
379 }
380
381 }
382 }
383
384 };
385
386
387
388
389
390 //! \brief
391 //! Matrix multplication where the matrices have differing element types.
392
393 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
394 class gemm_mixed
395 {
396 public:
397
398 //! immediate multiplication of matrices A and B, storing the result in C
399 template<typename out_eT, typename in_eT1, typename in_eT2>
400 inline
401 static
402 void
403 apply
404 (
405 Mat<out_eT>& C,
406 const Mat<in_eT1>& A,
407 const Mat<in_eT2>& B,
408 const out_eT alpha = out_eT(1),
409 const out_eT beta = out_eT(0)
410 )
411 {
412 arma_extra_debug_sigprint();
413
414 Mat<in_eT1> tmp_A;
415 Mat<in_eT2> tmp_B;
416
417 const bool predo_trans_A = ( (do_trans_A == true) && (is_complex<in_eT1>::value == true) );
418 const bool predo_trans_B = ( (do_trans_B == true) && (is_complex<in_eT2>::value == true) );
419
420 if(do_trans_A)
421 {
422 op_htrans::apply_noalias(tmp_A, A);
423 }
424
425 if(do_trans_B)
426 {
427 op_htrans::apply_noalias(tmp_B, B);
428 }
429
430 const Mat<in_eT1>& AA = (predo_trans_A == false) ? A : tmp_A;
431 const Mat<in_eT2>& BB = (predo_trans_B == false) ? B : tmp_B;
432
433 if( (AA.n_elem <= 64u) && (BB.n_elem <= 64u) )
434 {
435 gemm_mixed_small<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
436 }
437 else
438 {
439 gemm_mixed_large<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
440 }
441 }
442
443
444 };
445
446
447
448 //! @}