Mercurial > hg > segmenter-vamp-plugin
comparison armadillo-2.4.4/include/armadillo_bits/gemm_mixed.hpp @ 0:8b6102e2a9b0
Armadillo Library
author | maxzanoni76 <max.zanoni@eecs.qmul.ac.uk> |
---|---|
date | Wed, 11 Apr 2012 09:27:06 +0100 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:8b6102e2a9b0 |
---|---|
1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au) | |
2 // Copyright (C) 2008-2011 Conrad Sanderson | |
3 // | |
4 // This file is part of the Armadillo C++ library. | |
5 // It is provided without any warranty of fitness | |
6 // for any purpose. You can redistribute this file | |
7 // and/or modify it under the terms of the GNU | |
8 // Lesser General Public License (LGPL) as published | |
9 // by the Free Software Foundation, either version 3 | |
10 // of the License or (at your option) any later version. | |
11 // (see http://www.opensource.org/licenses for more info) | |
12 | |
13 | |
14 //! \addtogroup gemm_mixed | |
15 //! @{ | |
16 | |
17 | |
18 | |
19 //! \brief | |
20 //! Matrix multplication where the matrices have differing element types. | |
21 //! Uses caching for speedup. | |
22 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes) | |
23 | |
24 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false> | |
25 class gemm_mixed_large | |
26 { | |
27 public: | |
28 | |
29 template<typename out_eT, typename in_eT1, typename in_eT2> | |
30 arma_hot | |
31 inline | |
32 static | |
33 void | |
34 apply | |
35 ( | |
36 Mat<out_eT>& C, | |
37 const Mat<in_eT1>& A, | |
38 const Mat<in_eT2>& B, | |
39 const out_eT alpha = out_eT(1), | |
40 const out_eT beta = out_eT(0) | |
41 ) | |
42 { | |
43 arma_extra_debug_sigprint(); | |
44 | |
45 const uword A_n_rows = A.n_rows; | |
46 const uword A_n_cols = A.n_cols; | |
47 | |
48 const uword B_n_rows = B.n_rows; | |
49 const uword B_n_cols = B.n_cols; | |
50 | |
51 if( (do_trans_A == false) && (do_trans_B == false) ) | |
52 { | |
53 podarray<in_eT1> tmp(A_n_cols); | |
54 in_eT1* A_rowdata = tmp.memptr(); | |
55 | |
56 for(uword row_A=0; row_A < A_n_rows; ++row_A) | |
57 { | |
58 tmp.copy_row(A, row_A); | |
59 | |
60 for(uword col_B=0; col_B < B_n_cols; ++col_B) | |
61 { | |
62 const in_eT2* B_coldata = B.colptr(col_B); | |
63 | |
64 out_eT acc = out_eT(0); | |
65 for(uword i=0; i < B_n_rows; ++i) | |
66 { | |
67 acc += upgrade_val<in_eT1,in_eT2>::apply(A_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); | |
68 } | |
69 | |
70 if( (use_alpha == false) && (use_beta == false) ) | |
71 { | |
72 C.at(row_A,col_B) = acc; | |
73 } | |
74 else | |
75 if( (use_alpha == true) && (use_beta == false) ) | |
76 { | |
77 C.at(row_A,col_B) = alpha * acc; | |
78 } | |
79 else | |
80 if( (use_alpha == false) && (use_beta == true) ) | |
81 { | |
82 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); | |
83 } | |
84 else | |
85 if( (use_alpha == true) && (use_beta == true) ) | |
86 { | |
87 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); | |
88 } | |
89 | |
90 } | |
91 } | |
92 } | |
93 else | |
94 if( (do_trans_A == true) && (do_trans_B == false) ) | |
95 { | |
96 for(uword col_A=0; col_A < A_n_cols; ++col_A) | |
97 { | |
98 // col_A is interpreted as row_A when storing the results in matrix C | |
99 | |
100 const in_eT1* A_coldata = A.colptr(col_A); | |
101 | |
102 for(uword col_B=0; col_B < B_n_cols; ++col_B) | |
103 { | |
104 const in_eT2* B_coldata = B.colptr(col_B); | |
105 | |
106 out_eT acc = out_eT(0); | |
107 for(uword i=0; i < B_n_rows; ++i) | |
108 { | |
109 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); | |
110 } | |
111 | |
112 if( (use_alpha == false) && (use_beta == false) ) | |
113 { | |
114 C.at(col_A,col_B) = acc; | |
115 } | |
116 else | |
117 if( (use_alpha == true) && (use_beta == false) ) | |
118 { | |
119 C.at(col_A,col_B) = alpha * acc; | |
120 } | |
121 else | |
122 if( (use_alpha == false) && (use_beta == true) ) | |
123 { | |
124 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); | |
125 } | |
126 else | |
127 if( (use_alpha == true) && (use_beta == true) ) | |
128 { | |
129 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); | |
130 } | |
131 | |
132 } | |
133 } | |
134 } | |
135 else | |
136 if( (do_trans_A == false) && (do_trans_B == true) ) | |
137 { | |
138 Mat<in_eT2> B_tmp; | |
139 | |
140 op_strans::apply_noalias(B_tmp, B); | |
141 | |
142 gemm_mixed_large<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta); | |
143 } | |
144 else | |
145 if( (do_trans_A == true) && (do_trans_B == true) ) | |
146 { | |
147 // mat B_tmp = trans(B); | |
148 // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta); | |
149 | |
150 | |
151 // By using the trans(A)*trans(B) = trans(B*A) equivalency, | |
152 // transpose operations are not needed | |
153 | |
154 podarray<in_eT2> tmp(B_n_cols); | |
155 in_eT2* B_rowdata = tmp.memptr(); | |
156 | |
157 for(uword row_B=0; row_B < B_n_rows; ++row_B) | |
158 { | |
159 tmp.copy_row(B, row_B); | |
160 | |
161 for(uword col_A=0; col_A < A_n_cols; ++col_A) | |
162 { | |
163 const in_eT1* A_coldata = A.colptr(col_A); | |
164 | |
165 out_eT acc = out_eT(0); | |
166 for(uword i=0; i < A_n_rows; ++i) | |
167 { | |
168 acc += upgrade_val<in_eT1,in_eT2>::apply(B_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]); | |
169 } | |
170 | |
171 if( (use_alpha == false) && (use_beta == false) ) | |
172 { | |
173 C.at(col_A,row_B) = acc; | |
174 } | |
175 else | |
176 if( (use_alpha == true) && (use_beta == false) ) | |
177 { | |
178 C.at(col_A,row_B) = alpha * acc; | |
179 } | |
180 else | |
181 if( (use_alpha == false) && (use_beta == true) ) | |
182 { | |
183 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B); | |
184 } | |
185 else | |
186 if( (use_alpha == true) && (use_beta == true) ) | |
187 { | |
188 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); | |
189 } | |
190 | |
191 } | |
192 } | |
193 | |
194 } | |
195 } | |
196 | |
197 }; | |
198 | |
199 | |
200 | |
201 //! Matrix multplication where the matrices have different element types. | |
202 //! Simple version (no caching). | |
203 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes) | |
204 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false> | |
205 class gemm_mixed_small | |
206 { | |
207 public: | |
208 | |
209 template<typename out_eT, typename in_eT1, typename in_eT2> | |
210 arma_hot | |
211 inline | |
212 static | |
213 void | |
214 apply | |
215 ( | |
216 Mat<out_eT>& C, | |
217 const Mat<in_eT1>& A, | |
218 const Mat<in_eT2>& B, | |
219 const out_eT alpha = out_eT(1), | |
220 const out_eT beta = out_eT(0) | |
221 ) | |
222 { | |
223 arma_extra_debug_sigprint(); | |
224 | |
225 const uword A_n_rows = A.n_rows; | |
226 const uword A_n_cols = A.n_cols; | |
227 | |
228 const uword B_n_rows = B.n_rows; | |
229 const uword B_n_cols = B.n_cols; | |
230 | |
231 if( (do_trans_A == false) && (do_trans_B == false) ) | |
232 { | |
233 for(uword row_A = 0; row_A < A_n_rows; ++row_A) | |
234 { | |
235 for(uword col_B = 0; col_B < B_n_cols; ++col_B) | |
236 { | |
237 const in_eT2* B_coldata = B.colptr(col_B); | |
238 | |
239 out_eT acc = out_eT(0); | |
240 for(uword i = 0; i < B_n_rows; ++i) | |
241 { | |
242 const out_eT val1 = upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)); | |
243 const out_eT val2 = upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); | |
244 acc += val1 * val2; | |
245 //acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); | |
246 } | |
247 | |
248 if( (use_alpha == false) && (use_beta == false) ) | |
249 { | |
250 C.at(row_A,col_B) = acc; | |
251 } | |
252 else | |
253 if( (use_alpha == true) && (use_beta == false) ) | |
254 { | |
255 C.at(row_A,col_B) = alpha * acc; | |
256 } | |
257 else | |
258 if( (use_alpha == false) && (use_beta == true) ) | |
259 { | |
260 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B); | |
261 } | |
262 else | |
263 if( (use_alpha == true) && (use_beta == true) ) | |
264 { | |
265 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B); | |
266 } | |
267 } | |
268 } | |
269 } | |
270 else | |
271 if( (do_trans_A == true) && (do_trans_B == false) ) | |
272 { | |
273 for(uword col_A=0; col_A < A_n_cols; ++col_A) | |
274 { | |
275 // col_A is interpreted as row_A when storing the results in matrix C | |
276 | |
277 const in_eT1* A_coldata = A.colptr(col_A); | |
278 | |
279 for(uword col_B=0; col_B < B_n_cols; ++col_B) | |
280 { | |
281 const in_eT2* B_coldata = B.colptr(col_B); | |
282 | |
283 out_eT acc = out_eT(0); | |
284 for(uword i=0; i < B_n_rows; ++i) | |
285 { | |
286 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]); | |
287 } | |
288 | |
289 if( (use_alpha == false) && (use_beta == false) ) | |
290 { | |
291 C.at(col_A,col_B) = acc; | |
292 } | |
293 else | |
294 if( (use_alpha == true) && (use_beta == false) ) | |
295 { | |
296 C.at(col_A,col_B) = alpha * acc; | |
297 } | |
298 else | |
299 if( (use_alpha == false) && (use_beta == true) ) | |
300 { | |
301 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B); | |
302 } | |
303 else | |
304 if( (use_alpha == true) && (use_beta == true) ) | |
305 { | |
306 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B); | |
307 } | |
308 | |
309 } | |
310 } | |
311 } | |
312 else | |
313 if( (do_trans_A == false) && (do_trans_B == true) ) | |
314 { | |
315 for(uword row_A = 0; row_A < A_n_rows; ++row_A) | |
316 { | |
317 for(uword row_B = 0; row_B < B_n_rows; ++row_B) | |
318 { | |
319 out_eT acc = out_eT(0); | |
320 for(uword i = 0; i < B_n_cols; ++i) | |
321 { | |
322 acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)); | |
323 } | |
324 | |
325 if( (use_alpha == false) && (use_beta == false) ) | |
326 { | |
327 C.at(row_A,row_B) = acc; | |
328 } | |
329 else | |
330 if( (use_alpha == true) && (use_beta == false) ) | |
331 { | |
332 C.at(row_A,row_B) = alpha * acc; | |
333 } | |
334 else | |
335 if( (use_alpha == false) && (use_beta == true) ) | |
336 { | |
337 C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B); | |
338 } | |
339 else | |
340 if( (use_alpha == true) && (use_beta == true) ) | |
341 { | |
342 C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B); | |
343 } | |
344 } | |
345 } | |
346 } | |
347 else | |
348 if( (do_trans_A == true) && (do_trans_B == true) ) | |
349 { | |
350 for(uword row_B=0; row_B < B_n_rows; ++row_B) | |
351 { | |
352 | |
353 for(uword col_A=0; col_A < A_n_cols; ++col_A) | |
354 { | |
355 const in_eT1* A_coldata = A.colptr(col_A); | |
356 | |
357 out_eT acc = out_eT(0); | |
358 for(uword i=0; i < A_n_rows; ++i) | |
359 { | |
360 acc += upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]); | |
361 } | |
362 | |
363 if( (use_alpha == false) && (use_beta == false) ) | |
364 { | |
365 C.at(col_A,row_B) = acc; | |
366 } | |
367 else | |
368 if( (use_alpha == true) && (use_beta == false) ) | |
369 { | |
370 C.at(col_A,row_B) = alpha * acc; | |
371 } | |
372 else | |
373 if( (use_alpha == false) && (use_beta == true) ) | |
374 { | |
375 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B); | |
376 } | |
377 else | |
378 if( (use_alpha == true) && (use_beta == true) ) | |
379 { | |
380 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B); | |
381 } | |
382 | |
383 } | |
384 } | |
385 | |
386 } | |
387 } | |
388 | |
389 }; | |
390 | |
391 | |
392 | |
393 | |
394 | |
395 //! \brief | |
396 //! Matrix multplication where the matrices have differing element types. | |
397 | |
398 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false> | |
399 class gemm_mixed | |
400 { | |
401 public: | |
402 | |
403 //! immediate multiplication of matrices A and B, storing the result in C | |
404 template<typename out_eT, typename in_eT1, typename in_eT2> | |
405 inline | |
406 static | |
407 void | |
408 apply | |
409 ( | |
410 Mat<out_eT>& C, | |
411 const Mat<in_eT1>& A, | |
412 const Mat<in_eT2>& B, | |
413 const out_eT alpha = out_eT(1), | |
414 const out_eT beta = out_eT(0) | |
415 ) | |
416 { | |
417 arma_extra_debug_sigprint(); | |
418 | |
419 Mat<in_eT1> tmp_A; | |
420 Mat<in_eT2> tmp_B; | |
421 | |
422 const bool predo_trans_A = ( (do_trans_A == true) && (is_complex<in_eT1>::value == true) ); | |
423 const bool predo_trans_B = ( (do_trans_B == true) && (is_complex<in_eT2>::value == true) ); | |
424 | |
425 if(do_trans_A) | |
426 { | |
427 op_htrans::apply_noalias(tmp_A, A); | |
428 } | |
429 | |
430 if(do_trans_B) | |
431 { | |
432 op_htrans::apply_noalias(tmp_B, B); | |
433 } | |
434 | |
435 const Mat<in_eT1>& AA = (predo_trans_A == false) ? A : tmp_A; | |
436 const Mat<in_eT2>& BB = (predo_trans_B == false) ? B : tmp_B; | |
437 | |
438 if( (AA.n_elem <= 64u) && (BB.n_elem <= 64u) ) | |
439 { | |
440 gemm_mixed_small<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta); | |
441 } | |
442 else | |
443 { | |
444 gemm_mixed_large<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta); | |
445 } | |
446 } | |
447 | |
448 | |
449 }; | |
450 | |
451 | |
452 | |
453 //! @} |