Mercurial > hg > segmenter-vamp-plugin
comparison armadillo-3.900.4/include/armadillo_bits/gemm_mixed.hpp @ 49:1ec0e2823891
Switch to using subrepo copies of qm-dsp, nnls-chroma, vamp-plugin-sdk; update Armadillo version; assume build without external BLAS/LAPACK
author | Chris Cannam |
---|---|
date | Thu, 13 Jun 2013 10:25:24 +0100 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
48:69251e11a913 | 49:1ec0e2823891 |
---|---|
1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au) | |
2 // Copyright (C) 2008-2011 Conrad Sanderson | |
3 // | |
4 // This Source Code Form is subject to the terms of the Mozilla Public | |
5 // License, v. 2.0. If a copy of the MPL was not distributed with this | |
6 // file, You can obtain one at http://mozilla.org/MPL/2.0/. | |
7 | |
8 | |
9 //! \addtogroup gemm_mixed | |
10 //! @{ | |
11 | |
12 | |
13 | |
14 //! \brief | |
15 //! Matrix multplication where the matrices have differing element types. | |
16 //! Uses caching for speedup. | |
17 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes) | |
18 | |
19 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false> | |
20 class gemm_mixed_large | |
21 { | |
22 public: | |
23 | |
24 template<typename out_eT, typename in_eT1, typename in_eT2> | |
25 arma_hot | |
26 inline | |
27 static | |
28 void | |
29 apply | |
30 ( | |
31 Mat<out_eT>& C, | |
32 const Mat<in_eT1>& A, | |
33 const Mat<in_eT2>& B, | |
34 const out_eT alpha = out_eT(1), | |
35 const out_eT beta = out_eT(0) | |
36 ) | |
37 { | |
38 arma_extra_debug_sigprint(); | |
39 | |
40 const uword A_n_rows = A.n_rows; | |
41 const uword A_n_cols = A.n_cols; | |
42 | |
43 const uword B_n_rows = B.n_rows; | |
44 const uword B_n_cols = B.n_cols; | |
45 | |
46 if( (do_trans_A == false) && (do_trans_B == false) ) | |
47 { | |
48 podarray<in_eT1> tmp(A_n_cols); | |
49 in_eT1* A_rowdata = tmp.memptr(); | |
50 | |
51 for(uword row_A=0; row_A < A_n_rows; ++row_A) | |
52 { | |
53 tmp.copy_row(A, row_A); | |
54 | |
55 for(uword col_B=0; col_B < B_n_cols; ++col_B) | |
56 { | |
57 const in_eT2* B_coldata = B.colptr(col_B); | |
58 | |
59 out_eT acc = out_eT(0); | |
60 for(uword i=0; i < B_n_rows; ++i) | |
61 { | |
62 acc += upgrade_val<in_eT1,in_eT2>::apply(A_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); | |
63 } | |
64 | |
65 if( (use_alpha == false) && (use_beta == false) ) | |
66 { | |
67 C.at(row_A,col_B) = acc; | |
68 } | |
69 else | |
70 if( (use_alpha == true) && (use_beta == false) ) | |
71 { | |
72 C.at(row_A,col_B) = alpha * acc; | |
73 } | |
74 else | |
75 if( (use_alpha == false) && (use_beta == true) ) | |
76 { | |
77 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); | |
78 } | |
79 else | |
80 if( (use_alpha == true) && (use_beta == true) ) | |
81 { | |
82 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); | |
83 } | |
84 | |
85 } | |
86 } | |
87 } | |
88 else | |
89 if( (do_trans_A == true) && (do_trans_B == false) ) | |
90 { | |
91 for(uword col_A=0; col_A < A_n_cols; ++col_A) | |
92 { | |
93 // col_A is interpreted as row_A when storing the results in matrix C | |
94 | |
95 const in_eT1* A_coldata = A.colptr(col_A); | |
96 | |
97 for(uword col_B=0; col_B < B_n_cols; ++col_B) | |
98 { | |
99 const in_eT2* B_coldata = B.colptr(col_B); | |
100 | |
101 out_eT acc = out_eT(0); | |
102 for(uword i=0; i < B_n_rows; ++i) | |
103 { | |
104 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); | |
105 } | |
106 | |
107 if( (use_alpha == false) && (use_beta == false) ) | |
108 { | |
109 C.at(col_A,col_B) = acc; | |
110 } | |
111 else | |
112 if( (use_alpha == true) && (use_beta == false) ) | |
113 { | |
114 C.at(col_A,col_B) = alpha * acc; | |
115 } | |
116 else | |
117 if( (use_alpha == false) && (use_beta == true) ) | |
118 { | |
119 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); | |
120 } | |
121 else | |
122 if( (use_alpha == true) && (use_beta == true) ) | |
123 { | |
124 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); | |
125 } | |
126 | |
127 } | |
128 } | |
129 } | |
130 else | |
131 if( (do_trans_A == false) && (do_trans_B == true) ) | |
132 { | |
133 Mat<in_eT2> B_tmp; | |
134 | |
135 op_strans::apply_noalias(B_tmp, B); | |
136 | |
137 gemm_mixed_large<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta); | |
138 } | |
139 else | |
140 if( (do_trans_A == true) && (do_trans_B == true) ) | |
141 { | |
142 // mat B_tmp = trans(B); | |
143 // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta); | |
144 | |
145 | |
146 // By using the trans(A)*trans(B) = trans(B*A) equivalency, | |
147 // transpose operations are not needed | |
148 | |
149 podarray<in_eT2> tmp(B_n_cols); | |
150 in_eT2* B_rowdata = tmp.memptr(); | |
151 | |
152 for(uword row_B=0; row_B < B_n_rows; ++row_B) | |
153 { | |
154 tmp.copy_row(B, row_B); | |
155 | |
156 for(uword col_A=0; col_A < A_n_cols; ++col_A) | |
157 { | |
158 const in_eT1* A_coldata = A.colptr(col_A); | |
159 | |
160 out_eT acc = out_eT(0); | |
161 for(uword i=0; i < A_n_rows; ++i) | |
162 { | |
163 acc += upgrade_val<in_eT1,in_eT2>::apply(B_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]); | |
164 } | |
165 | |
166 if( (use_alpha == false) && (use_beta == false) ) | |
167 { | |
168 C.at(col_A,row_B) = acc; | |
169 } | |
170 else | |
171 if( (use_alpha == true) && (use_beta == false) ) | |
172 { | |
173 C.at(col_A,row_B) = alpha * acc; | |
174 } | |
175 else | |
176 if( (use_alpha == false) && (use_beta == true) ) | |
177 { | |
178 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B); | |
179 } | |
180 else | |
181 if( (use_alpha == true) && (use_beta == true) ) | |
182 { | |
183 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); | |
184 } | |
185 | |
186 } | |
187 } | |
188 | |
189 } | |
190 } | |
191 | |
192 }; | |
193 | |
194 | |
195 | |
196 //! Matrix multplication where the matrices have different element types. | |
197 //! Simple version (no caching). | |
198 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes) | |
199 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false> | |
200 class gemm_mixed_small | |
201 { | |
202 public: | |
203 | |
204 template<typename out_eT, typename in_eT1, typename in_eT2> | |
205 arma_hot | |
206 inline | |
207 static | |
208 void | |
209 apply | |
210 ( | |
211 Mat<out_eT>& C, | |
212 const Mat<in_eT1>& A, | |
213 const Mat<in_eT2>& B, | |
214 const out_eT alpha = out_eT(1), | |
215 const out_eT beta = out_eT(0) | |
216 ) | |
217 { | |
218 arma_extra_debug_sigprint(); | |
219 | |
220 const uword A_n_rows = A.n_rows; | |
221 const uword A_n_cols = A.n_cols; | |
222 | |
223 const uword B_n_rows = B.n_rows; | |
224 const uword B_n_cols = B.n_cols; | |
225 | |
226 if( (do_trans_A == false) && (do_trans_B == false) ) | |
227 { | |
228 for(uword row_A = 0; row_A < A_n_rows; ++row_A) | |
229 { | |
230 for(uword col_B = 0; col_B < B_n_cols; ++col_B) | |
231 { | |
232 const in_eT2* B_coldata = B.colptr(col_B); | |
233 | |
234 out_eT acc = out_eT(0); | |
235 for(uword i = 0; i < B_n_rows; ++i) | |
236 { | |
237 const out_eT val1 = upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)); | |
238 const out_eT val2 = upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); | |
239 acc += val1 * val2; | |
240 //acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); | |
241 } | |
242 | |
243 if( (use_alpha == false) && (use_beta == false) ) | |
244 { | |
245 C.at(row_A,col_B) = acc; | |
246 } | |
247 else | |
248 if( (use_alpha == true) && (use_beta == false) ) | |
249 { | |
250 C.at(row_A,col_B) = alpha * acc; | |
251 } | |
252 else | |
253 if( (use_alpha == false) && (use_beta == true) ) | |
254 { | |
255 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); | |
256 } | |
257 else | |
258 if( (use_alpha == true) && (use_beta == true) ) | |
259 { | |
260 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); | |
261 } | |
262 } | |
263 } | |
264 } | |
265 else | |
266 if( (do_trans_A == true) && (do_trans_B == false) ) | |
267 { | |
268 for(uword col_A=0; col_A < A_n_cols; ++col_A) | |
269 { | |
270 // col_A is interpreted as row_A when storing the results in matrix C | |
271 | |
272 const in_eT1* A_coldata = A.colptr(col_A); | |
273 | |
274 for(uword col_B=0; col_B < B_n_cols; ++col_B) | |
275 { | |
276 const in_eT2* B_coldata = B.colptr(col_B); | |
277 | |
278 out_eT acc = out_eT(0); | |
279 for(uword i=0; i < B_n_rows; ++i) | |
280 { | |
281 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); | |
282 } | |
283 | |
284 if( (use_alpha == false) && (use_beta == false) ) | |
285 { | |
286 C.at(col_A,col_B) = acc; | |
287 } | |
288 else | |
289 if( (use_alpha == true) && (use_beta == false) ) | |
290 { | |
291 C.at(col_A,col_B) = alpha * acc; | |
292 } | |
293 else | |
294 if( (use_alpha == false) && (use_beta == true) ) | |
295 { | |
296 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); | |
297 } | |
298 else | |
299 if( (use_alpha == true) && (use_beta == true) ) | |
300 { | |
301 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); | |
302 } | |
303 | |
304 } | |
305 } | |
306 } | |
307 else | |
308 if( (do_trans_A == false) && (do_trans_B == true) ) | |
309 { | |
310 for(uword row_A = 0; row_A < A_n_rows; ++row_A) | |
311 { | |
312 for(uword row_B = 0; row_B < B_n_rows; ++row_B) | |
313 { | |
314 out_eT acc = out_eT(0); | |
315 for(uword i = 0; i < B_n_cols; ++i) | |
316 { | |
317 acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)); | |
318 } | |
319 | |
320 if( (use_alpha == false) && (use_beta == false) ) | |
321 { | |
322 C.at(row_A,row_B) = acc; | |
323 } | |
324 else | |
325 if( (use_alpha == true) && (use_beta == false) ) | |
326 { | |
327 C.at(row_A,row_B) = alpha * acc; | |
328 } | |
329 else | |
330 if( (use_alpha == false) && (use_beta == true) ) | |
331 { | |
332 C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B); | |
333 } | |
334 else | |
335 if( (use_alpha == true) && (use_beta == true) ) | |
336 { | |
337 C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B); | |
338 } | |
339 } | |
340 } | |
341 } | |
342 else | |
343 if( (do_trans_A == true) && (do_trans_B == true) ) | |
344 { | |
345 for(uword row_B=0; row_B < B_n_rows; ++row_B) | |
346 { | |
347 | |
348 for(uword col_A=0; col_A < A_n_cols; ++col_A) | |
349 { | |
350 const in_eT1* A_coldata = A.colptr(col_A); | |
351 | |
352 out_eT acc = out_eT(0); | |
353 for(uword i=0; i < A_n_rows; ++i) | |
354 { | |
355 acc += upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]); | |
356 } | |
357 | |
358 if( (use_alpha == false) && (use_beta == false) ) | |
359 { | |
360 C.at(col_A,row_B) = acc; | |
361 } | |
362 else | |
363 if( (use_alpha == true) && (use_beta == false) ) | |
364 { | |
365 C.at(col_A,row_B) = alpha * acc; | |
366 } | |
367 else | |
368 if( (use_alpha == false) && (use_beta == true) ) | |
369 { | |
370 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B); | |
371 } | |
372 else | |
373 if( (use_alpha == true) && (use_beta == true) ) | |
374 { | |
375 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); | |
376 } | |
377 | |
378 } | |
379 } | |
380 | |
381 } | |
382 } | |
383 | |
384 }; | |
385 | |
386 | |
387 | |
388 | |
389 | |
390 //! \brief | |
391 //! Matrix multplication where the matrices have differing element types. | |
392 | |
393 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false> | |
394 class gemm_mixed | |
395 { | |
396 public: | |
397 | |
398 //! immediate multiplication of matrices A and B, storing the result in C | |
399 template<typename out_eT, typename in_eT1, typename in_eT2> | |
400 inline | |
401 static | |
402 void | |
403 apply | |
404 ( | |
405 Mat<out_eT>& C, | |
406 const Mat<in_eT1>& A, | |
407 const Mat<in_eT2>& B, | |
408 const out_eT alpha = out_eT(1), | |
409 const out_eT beta = out_eT(0) | |
410 ) | |
411 { | |
412 arma_extra_debug_sigprint(); | |
413 | |
414 Mat<in_eT1> tmp_A; | |
415 Mat<in_eT2> tmp_B; | |
416 | |
417 const bool predo_trans_A = ( (do_trans_A == true) && (is_complex<in_eT1>::value == true) ); | |
418 const bool predo_trans_B = ( (do_trans_B == true) && (is_complex<in_eT2>::value == true) ); | |
419 | |
420 if(do_trans_A) | |
421 { | |
422 op_htrans::apply_noalias(tmp_A, A); | |
423 } | |
424 | |
425 if(do_trans_B) | |
426 { | |
427 op_htrans::apply_noalias(tmp_B, B); | |
428 } | |
429 | |
430 const Mat<in_eT1>& AA = (predo_trans_A == false) ? A : tmp_A; | |
431 const Mat<in_eT2>& BB = (predo_trans_B == false) ? B : tmp_B; | |
432 | |
433 if( (AA.n_elem <= 64u) && (BB.n_elem <= 64u) ) | |
434 { | |
435 gemm_mixed_small<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta); | |
436 } | |
437 else | |
438 { | |
439 gemm_mixed_large<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta); | |
440 } | |
441 } | |
442 | |
443 | |
444 }; | |
445 | |
446 | |
447 | |
448 //! @} |