comparison armadillo-2.4.4/include/armadillo_bits/gemm_mixed.hpp @ 0:8b6102e2a9b0

Armadillo Library
author maxzanoni76 <max.zanoni@eecs.qmul.ac.uk>
date Wed, 11 Apr 2012 09:27:06 +0100
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:8b6102e2a9b0
1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
2 // Copyright (C) 2008-2011 Conrad Sanderson
3 //
4 // This file is part of the Armadillo C++ library.
5 // It is provided without any warranty of fitness
6 // for any purpose. You can redistribute this file
7 // and/or modify it under the terms of the GNU
8 // Lesser General Public License (LGPL) as published
9 // by the Free Software Foundation, either version 3
10 // of the License or (at your option) any later version.
11 // (see http://www.opensource.org/licenses for more info)
12
13
14 //! \addtogroup gemm_mixed
15 //! @{
16
17
18
19 //! \brief
20 //! Matrix multiplication where the matrices have differing element types.
21 //! Uses caching for speedup.
22 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
23
24 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
25 class gemm_mixed_large
26 {
27 public:
28
29 template<typename out_eT, typename in_eT1, typename in_eT2>
30 arma_hot
31 inline
32 static
33 void
34 apply
35 (
36 Mat<out_eT>& C,
37 const Mat<in_eT1>& A,
38 const Mat<in_eT2>& B,
39 const out_eT alpha = out_eT(1),
40 const out_eT beta = out_eT(0)
41 )
42 {
43 arma_extra_debug_sigprint();
44
45 const uword A_n_rows = A.n_rows;
46 const uword A_n_cols = A.n_cols;
47
48 const uword B_n_rows = B.n_rows;
49 const uword B_n_cols = B.n_cols;
50
51 if( (do_trans_A == false) && (do_trans_B == false) )
52 {
53 podarray<in_eT1> tmp(A_n_cols);
54 in_eT1* A_rowdata = tmp.memptr();
55
56 for(uword row_A=0; row_A < A_n_rows; ++row_A)
57 {
58 tmp.copy_row(A, row_A);
59
60 for(uword col_B=0; col_B < B_n_cols; ++col_B)
61 {
62 const in_eT2* B_coldata = B.colptr(col_B);
63
64 out_eT acc = out_eT(0);
65 for(uword i=0; i < B_n_rows; ++i)
66 {
67 acc += upgrade_val<in_eT1,in_eT2>::apply(A_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
68 }
69
70 if( (use_alpha == false) && (use_beta == false) )
71 {
72 C.at(row_A,col_B) = acc;
73 }
74 else
75 if( (use_alpha == true) && (use_beta == false) )
76 {
77 C.at(row_A,col_B) = alpha * acc;
78 }
79 else
80 if( (use_alpha == false) && (use_beta == true) )
81 {
82 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
83 }
84 else
85 if( (use_alpha == true) && (use_beta == true) )
86 {
87 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
88 }
89
90 }
91 }
92 }
93 else
94 if( (do_trans_A == true) && (do_trans_B == false) )
95 {
96 for(uword col_A=0; col_A < A_n_cols; ++col_A)
97 {
98 // col_A is interpreted as row_A when storing the results in matrix C
99
100 const in_eT1* A_coldata = A.colptr(col_A);
101
102 for(uword col_B=0; col_B < B_n_cols; ++col_B)
103 {
104 const in_eT2* B_coldata = B.colptr(col_B);
105
106 out_eT acc = out_eT(0);
107 for(uword i=0; i < B_n_rows; ++i)
108 {
109 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
110 }
111
112 if( (use_alpha == false) && (use_beta == false) )
113 {
114 C.at(col_A,col_B) = acc;
115 }
116 else
117 if( (use_alpha == true) && (use_beta == false) )
118 {
119 C.at(col_A,col_B) = alpha * acc;
120 }
121 else
122 if( (use_alpha == false) && (use_beta == true) )
123 {
124 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
125 }
126 else
127 if( (use_alpha == true) && (use_beta == true) )
128 {
129 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
130 }
131
132 }
133 }
134 }
135 else
136 if( (do_trans_A == false) && (do_trans_B == true) )
137 {
138 Mat<in_eT2> B_tmp;
139
140 op_strans::apply_noalias(B_tmp, B);
141
142 gemm_mixed_large<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
143 }
144 else
145 if( (do_trans_A == true) && (do_trans_B == true) )
146 {
147 // mat B_tmp = trans(B);
148 // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
149
150
151 // By using the trans(A)*trans(B) = trans(B*A) equivalency,
152 // transpose operations are not needed
153
154 podarray<in_eT2> tmp(B_n_cols);
155 in_eT2* B_rowdata = tmp.memptr();
156
157 for(uword row_B=0; row_B < B_n_rows; ++row_B)
158 {
159 tmp.copy_row(B, row_B);
160
161 for(uword col_A=0; col_A < A_n_cols; ++col_A)
162 {
163 const in_eT1* A_coldata = A.colptr(col_A);
164
165 out_eT acc = out_eT(0);
166 for(uword i=0; i < A_n_rows; ++i)
167 {
168 acc += upgrade_val<in_eT1,in_eT2>::apply(B_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
169 }
170
171 if( (use_alpha == false) && (use_beta == false) )
172 {
173 C.at(col_A,row_B) = acc;
174 }
175 else
176 if( (use_alpha == true) && (use_beta == false) )
177 {
178 C.at(col_A,row_B) = alpha * acc;
179 }
180 else
181 if( (use_alpha == false) && (use_beta == true) )
182 {
183 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
184 }
185 else
186 if( (use_alpha == true) && (use_beta == true) )
187 {
188 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
189 }
190
191 }
192 }
193
194 }
195 }
196
197 };
198
199
200
201 //! Matrix multiplication where the matrices have different element types.
202 //! Simple version (no caching).
203 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
204 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
205 class gemm_mixed_small
206 {
207 public:
208
209 template<typename out_eT, typename in_eT1, typename in_eT2>
210 arma_hot
211 inline
212 static
213 void
214 apply
215 (
216 Mat<out_eT>& C,
217 const Mat<in_eT1>& A,
218 const Mat<in_eT2>& B,
219 const out_eT alpha = out_eT(1),
220 const out_eT beta = out_eT(0)
221 )
222 {
223 arma_extra_debug_sigprint();
224
225 const uword A_n_rows = A.n_rows;
226 const uword A_n_cols = A.n_cols;
227
228 const uword B_n_rows = B.n_rows;
229 const uword B_n_cols = B.n_cols;
230
231 if( (do_trans_A == false) && (do_trans_B == false) )
232 {
233 for(uword row_A = 0; row_A < A_n_rows; ++row_A)
234 {
235 for(uword col_B = 0; col_B < B_n_cols; ++col_B)
236 {
237 const in_eT2* B_coldata = B.colptr(col_B);
238
239 out_eT acc = out_eT(0);
240 for(uword i = 0; i < B_n_rows; ++i)
241 {
242 const out_eT val1 = upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i));
243 const out_eT val2 = upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
244 acc += val1 * val2;
245 //acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
246 }
247
248 if( (use_alpha == false) && (use_beta == false) )
249 {
250 C.at(row_A,col_B) = acc;
251 }
252 else
253 if( (use_alpha == true) && (use_beta == false) )
254 {
255 C.at(row_A,col_B) = alpha * acc;
256 }
257 else
258 if( (use_alpha == false) && (use_beta == true) )
259 {
260 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
261 }
262 else
263 if( (use_alpha == true) && (use_beta == true) )
264 {
265 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
266 }
267 }
268 }
269 }
270 else
271 if( (do_trans_A == true) && (do_trans_B == false) )
272 {
273 for(uword col_A=0; col_A < A_n_cols; ++col_A)
274 {
275 // col_A is interpreted as row_A when storing the results in matrix C
276
277 const in_eT1* A_coldata = A.colptr(col_A);
278
279 for(uword col_B=0; col_B < B_n_cols; ++col_B)
280 {
281 const in_eT2* B_coldata = B.colptr(col_B);
282
283 out_eT acc = out_eT(0);
284 for(uword i=0; i < B_n_rows; ++i)
285 {
286 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
287 }
288
289 if( (use_alpha == false) && (use_beta == false) )
290 {
291 C.at(col_A,col_B) = acc;
292 }
293 else
294 if( (use_alpha == true) && (use_beta == false) )
295 {
296 C.at(col_A,col_B) = alpha * acc;
297 }
298 else
299 if( (use_alpha == false) && (use_beta == true) )
300 {
301 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
302 }
303 else
304 if( (use_alpha == true) && (use_beta == true) )
305 {
306 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
307 }
308
309 }
310 }
311 }
312 else
313 if( (do_trans_A == false) && (do_trans_B == true) )
314 {
315 for(uword row_A = 0; row_A < A_n_rows; ++row_A)
316 {
317 for(uword row_B = 0; row_B < B_n_rows; ++row_B)
318 {
319 out_eT acc = out_eT(0);
320 for(uword i = 0; i < B_n_cols; ++i)
321 {
322 acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i));
323 }
324
325 if( (use_alpha == false) && (use_beta == false) )
326 {
327 C.at(row_A,row_B) = acc;
328 }
329 else
330 if( (use_alpha == true) && (use_beta == false) )
331 {
332 C.at(row_A,row_B) = alpha * acc;
333 }
334 else
335 if( (use_alpha == false) && (use_beta == true) )
336 {
337 C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B);
338 }
339 else
340 if( (use_alpha == true) && (use_beta == true) )
341 {
342 C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B);
343 }
344 }
345 }
346 }
347 else
348 if( (do_trans_A == true) && (do_trans_B == true) )
349 {
350 for(uword row_B=0; row_B < B_n_rows; ++row_B)
351 {
352
353 for(uword col_A=0; col_A < A_n_cols; ++col_A)
354 {
355 const in_eT1* A_coldata = A.colptr(col_A);
356
357 out_eT acc = out_eT(0);
358 for(uword i=0; i < A_n_rows; ++i)
359 {
360 acc += upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
361 }
362
363 if( (use_alpha == false) && (use_beta == false) )
364 {
365 C.at(col_A,row_B) = acc;
366 }
367 else
368 if( (use_alpha == true) && (use_beta == false) )
369 {
370 C.at(col_A,row_B) = alpha * acc;
371 }
372 else
373 if( (use_alpha == false) && (use_beta == true) )
374 {
375 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
376 }
377 else
378 if( (use_alpha == true) && (use_beta == true) )
379 {
380 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
381 }
382
383 }
384 }
385
386 }
387 }
388
389 };
390
391
392
393
394
395 //! \brief
396 //! Matrix multiplication where the matrices have differing element types.
397
398 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
399 class gemm_mixed
400 {
401 public:
402
403 //! immediate multiplication of matrices A and B, storing the result in C
404 template<typename out_eT, typename in_eT1, typename in_eT2>
405 inline
406 static
407 void
408 apply
409 (
410 Mat<out_eT>& C,
411 const Mat<in_eT1>& A,
412 const Mat<in_eT2>& B,
413 const out_eT alpha = out_eT(1),
414 const out_eT beta = out_eT(0)
415 )
416 {
417 arma_extra_debug_sigprint();
418
419 Mat<in_eT1> tmp_A;
420 Mat<in_eT2> tmp_B;
421
422 const bool predo_trans_A = ( (do_trans_A == true) && (is_complex<in_eT1>::value == true) );
423 const bool predo_trans_B = ( (do_trans_B == true) && (is_complex<in_eT2>::value == true) );
424
425 if(do_trans_A)
426 {
427 op_htrans::apply_noalias(tmp_A, A);
428 }
429
430 if(do_trans_B)
431 {
432 op_htrans::apply_noalias(tmp_B, B);
433 }
434
435 const Mat<in_eT1>& AA = (predo_trans_A == false) ? A : tmp_A;
436 const Mat<in_eT2>& BB = (predo_trans_B == false) ? B : tmp_B;
437
438 if( (AA.n_elem <= 64u) && (BB.n_elem <= 64u) )
439 {
440 gemm_mixed_small<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
441 }
442 else
443 {
444 gemm_mixed_large<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
445 }
446 }
447
448
449 };
450
451
452
453 //! @}