max@0
|
1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
|
max@0
|
2 // Copyright (C) 2008-2011 Conrad Sanderson
|
max@0
|
3 //
|
max@0
|
4 // This file is part of the Armadillo C++ library.
|
max@0
|
5 // It is provided without any warranty of fitness
|
max@0
|
6 // for any purpose. You can redistribute this file
|
max@0
|
7 // and/or modify it under the terms of the GNU
|
max@0
|
8 // Lesser General Public License (LGPL) as published
|
max@0
|
9 // by the Free Software Foundation, either version 3
|
max@0
|
10 // of the License or (at your option) any later version.
|
max@0
|
11 // (see http://www.opensource.org/licenses for more info)
|
max@0
|
12
|
max@0
|
13
|
max@0
|
14 //! \addtogroup gemm_mixed
|
max@0
|
15 //! @{
|
max@0
|
16
|
max@0
|
17
|
max@0
|
18
|
max@0
|
19 //! \brief
|
max@0
|
20 //! Matrix multplication where the matrices have differing element types.
|
max@0
|
21 //! Uses caching for speedup.
|
max@0
|
22 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
|
max@0
|
23
|
max@0
|
24 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
|
max@0
|
25 class gemm_mixed_large
|
max@0
|
26 {
|
max@0
|
27 public:
|
max@0
|
28
|
max@0
|
29 template<typename out_eT, typename in_eT1, typename in_eT2>
|
max@0
|
30 arma_hot
|
max@0
|
31 inline
|
max@0
|
32 static
|
max@0
|
33 void
|
max@0
|
34 apply
|
max@0
|
35 (
|
max@0
|
36 Mat<out_eT>& C,
|
max@0
|
37 const Mat<in_eT1>& A,
|
max@0
|
38 const Mat<in_eT2>& B,
|
max@0
|
39 const out_eT alpha = out_eT(1),
|
max@0
|
40 const out_eT beta = out_eT(0)
|
max@0
|
41 )
|
max@0
|
42 {
|
max@0
|
43 arma_extra_debug_sigprint();
|
max@0
|
44
|
max@0
|
45 const uword A_n_rows = A.n_rows;
|
max@0
|
46 const uword A_n_cols = A.n_cols;
|
max@0
|
47
|
max@0
|
48 const uword B_n_rows = B.n_rows;
|
max@0
|
49 const uword B_n_cols = B.n_cols;
|
max@0
|
50
|
max@0
|
51 if( (do_trans_A == false) && (do_trans_B == false) )
|
max@0
|
52 {
|
max@0
|
53 podarray<in_eT1> tmp(A_n_cols);
|
max@0
|
54 in_eT1* A_rowdata = tmp.memptr();
|
max@0
|
55
|
max@0
|
56 for(uword row_A=0; row_A < A_n_rows; ++row_A)
|
max@0
|
57 {
|
max@0
|
58 tmp.copy_row(A, row_A);
|
max@0
|
59
|
max@0
|
60 for(uword col_B=0; col_B < B_n_cols; ++col_B)
|
max@0
|
61 {
|
max@0
|
62 const in_eT2* B_coldata = B.colptr(col_B);
|
max@0
|
63
|
max@0
|
64 out_eT acc = out_eT(0);
|
max@0
|
65 for(uword i=0; i < B_n_rows; ++i)
|
max@0
|
66 {
|
max@0
|
67 acc += upgrade_val<in_eT1,in_eT2>::apply(A_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
|
max@0
|
68 }
|
max@0
|
69
|
max@0
|
70 if( (use_alpha == false) && (use_beta == false) )
|
max@0
|
71 {
|
max@0
|
72 C.at(row_A,col_B) = acc;
|
max@0
|
73 }
|
max@0
|
74 else
|
max@0
|
75 if( (use_alpha == true) && (use_beta == false) )
|
max@0
|
76 {
|
max@0
|
77 C.at(row_A,col_B) = alpha * acc;
|
max@0
|
78 }
|
max@0
|
79 else
|
max@0
|
80 if( (use_alpha == false) && (use_beta == true) )
|
max@0
|
81 {
|
max@0
|
82 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
|
max@0
|
83 }
|
max@0
|
84 else
|
max@0
|
85 if( (use_alpha == true) && (use_beta == true) )
|
max@0
|
86 {
|
max@0
|
87 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
|
max@0
|
88 }
|
max@0
|
89
|
max@0
|
90 }
|
max@0
|
91 }
|
max@0
|
92 }
|
max@0
|
93 else
|
max@0
|
94 if( (do_trans_A == true) && (do_trans_B == false) )
|
max@0
|
95 {
|
max@0
|
96 for(uword col_A=0; col_A < A_n_cols; ++col_A)
|
max@0
|
97 {
|
max@0
|
98 // col_A is interpreted as row_A when storing the results in matrix C
|
max@0
|
99
|
max@0
|
100 const in_eT1* A_coldata = A.colptr(col_A);
|
max@0
|
101
|
max@0
|
102 for(uword col_B=0; col_B < B_n_cols; ++col_B)
|
max@0
|
103 {
|
max@0
|
104 const in_eT2* B_coldata = B.colptr(col_B);
|
max@0
|
105
|
max@0
|
106 out_eT acc = out_eT(0);
|
max@0
|
107 for(uword i=0; i < B_n_rows; ++i)
|
max@0
|
108 {
|
max@0
|
109 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
|
max@0
|
110 }
|
max@0
|
111
|
max@0
|
112 if( (use_alpha == false) && (use_beta == false) )
|
max@0
|
113 {
|
max@0
|
114 C.at(col_A,col_B) = acc;
|
max@0
|
115 }
|
max@0
|
116 else
|
max@0
|
117 if( (use_alpha == true) && (use_beta == false) )
|
max@0
|
118 {
|
max@0
|
119 C.at(col_A,col_B) = alpha * acc;
|
max@0
|
120 }
|
max@0
|
121 else
|
max@0
|
122 if( (use_alpha == false) && (use_beta == true) )
|
max@0
|
123 {
|
max@0
|
124 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
|
max@0
|
125 }
|
max@0
|
126 else
|
max@0
|
127 if( (use_alpha == true) && (use_beta == true) )
|
max@0
|
128 {
|
max@0
|
129 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
|
max@0
|
130 }
|
max@0
|
131
|
max@0
|
132 }
|
max@0
|
133 }
|
max@0
|
134 }
|
max@0
|
135 else
|
max@0
|
136 if( (do_trans_A == false) && (do_trans_B == true) )
|
max@0
|
137 {
|
max@0
|
138 Mat<in_eT2> B_tmp;
|
max@0
|
139
|
max@0
|
140 op_strans::apply_noalias(B_tmp, B);
|
max@0
|
141
|
max@0
|
142 gemm_mixed_large<false, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
|
max@0
|
143 }
|
max@0
|
144 else
|
max@0
|
145 if( (do_trans_A == true) && (do_trans_B == true) )
|
max@0
|
146 {
|
max@0
|
147 // mat B_tmp = trans(B);
|
max@0
|
148 // dgemm_arma<true, false, use_alpha, use_beta>::apply(C, A, B_tmp, alpha, beta);
|
max@0
|
149
|
max@0
|
150
|
max@0
|
151 // By using the trans(A)*trans(B) = trans(B*A) equivalency,
|
max@0
|
152 // transpose operations are not needed
|
max@0
|
153
|
max@0
|
154 podarray<in_eT2> tmp(B_n_cols);
|
max@0
|
155 in_eT2* B_rowdata = tmp.memptr();
|
max@0
|
156
|
max@0
|
157 for(uword row_B=0; row_B < B_n_rows; ++row_B)
|
max@0
|
158 {
|
max@0
|
159 tmp.copy_row(B, row_B);
|
max@0
|
160
|
max@0
|
161 for(uword col_A=0; col_A < A_n_cols; ++col_A)
|
max@0
|
162 {
|
max@0
|
163 const in_eT1* A_coldata = A.colptr(col_A);
|
max@0
|
164
|
max@0
|
165 out_eT acc = out_eT(0);
|
max@0
|
166 for(uword i=0; i < A_n_rows; ++i)
|
max@0
|
167 {
|
max@0
|
168 acc += upgrade_val<in_eT1,in_eT2>::apply(B_rowdata[i]) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
|
max@0
|
169 }
|
max@0
|
170
|
max@0
|
171 if( (use_alpha == false) && (use_beta == false) )
|
max@0
|
172 {
|
max@0
|
173 C.at(col_A,row_B) = acc;
|
max@0
|
174 }
|
max@0
|
175 else
|
max@0
|
176 if( (use_alpha == true) && (use_beta == false) )
|
max@0
|
177 {
|
max@0
|
178 C.at(col_A,row_B) = alpha * acc;
|
max@0
|
179 }
|
max@0
|
180 else
|
max@0
|
181 if( (use_alpha == false) && (use_beta == true) )
|
max@0
|
182 {
|
max@0
|
183 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
|
max@0
|
184 }
|
max@0
|
185 else
|
max@0
|
186 if( (use_alpha == true) && (use_beta == true) )
|
max@0
|
187 {
|
max@0
|
188 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
|
max@0
|
189 }
|
max@0
|
190
|
max@0
|
191 }
|
max@0
|
192 }
|
max@0
|
193
|
max@0
|
194 }
|
max@0
|
195 }
|
max@0
|
196
|
max@0
|
197 };
|
max@0
|
198
|
max@0
|
199
|
max@0
|
200
|
max@0
|
201 //! Matrix multplication where the matrices have different element types.
|
max@0
|
202 //! Simple version (no caching).
|
max@0
|
203 //! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes)
|
max@0
|
204 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
|
max@0
|
205 class gemm_mixed_small
|
max@0
|
206 {
|
max@0
|
207 public:
|
max@0
|
208
|
max@0
|
209 template<typename out_eT, typename in_eT1, typename in_eT2>
|
max@0
|
210 arma_hot
|
max@0
|
211 inline
|
max@0
|
212 static
|
max@0
|
213 void
|
max@0
|
214 apply
|
max@0
|
215 (
|
max@0
|
216 Mat<out_eT>& C,
|
max@0
|
217 const Mat<in_eT1>& A,
|
max@0
|
218 const Mat<in_eT2>& B,
|
max@0
|
219 const out_eT alpha = out_eT(1),
|
max@0
|
220 const out_eT beta = out_eT(0)
|
max@0
|
221 )
|
max@0
|
222 {
|
max@0
|
223 arma_extra_debug_sigprint();
|
max@0
|
224
|
max@0
|
225 const uword A_n_rows = A.n_rows;
|
max@0
|
226 const uword A_n_cols = A.n_cols;
|
max@0
|
227
|
max@0
|
228 const uword B_n_rows = B.n_rows;
|
max@0
|
229 const uword B_n_cols = B.n_cols;
|
max@0
|
230
|
max@0
|
231 if( (do_trans_A == false) && (do_trans_B == false) )
|
max@0
|
232 {
|
max@0
|
233 for(uword row_A = 0; row_A < A_n_rows; ++row_A)
|
max@0
|
234 {
|
max@0
|
235 for(uword col_B = 0; col_B < B_n_cols; ++col_B)
|
max@0
|
236 {
|
max@0
|
237 const in_eT2* B_coldata = B.colptr(col_B);
|
max@0
|
238
|
max@0
|
239 out_eT acc = out_eT(0);
|
max@0
|
240 for(uword i = 0; i < B_n_rows; ++i)
|
max@0
|
241 {
|
max@0
|
242 const out_eT val1 = upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i));
|
max@0
|
243 const out_eT val2 = upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
|
max@0
|
244 acc += val1 * val2;
|
max@0
|
245 //acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
|
max@0
|
246 }
|
max@0
|
247
|
max@0
|
248 if( (use_alpha == false) && (use_beta == false) )
|
max@0
|
249 {
|
max@0
|
250 C.at(row_A,col_B) = acc;
|
max@0
|
251 }
|
max@0
|
252 else
|
max@0
|
253 if( (use_alpha == true) && (use_beta == false) )
|
max@0
|
254 {
|
max@0
|
255 C.at(row_A,col_B) = alpha * acc;
|
max@0
|
256 }
|
max@0
|
257 else
|
max@0
|
258 if( (use_alpha == false) && (use_beta == true) )
|
max@0
|
259 {
|
max@0
|
260 C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);
|
max@0
|
261 }
|
max@0
|
262 else
|
max@0
|
263 if( (use_alpha == true) && (use_beta == true) )
|
max@0
|
264 {
|
max@0
|
265 C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
|
max@0
|
266 }
|
max@0
|
267 }
|
max@0
|
268 }
|
max@0
|
269 }
|
max@0
|
270 else
|
max@0
|
271 if( (do_trans_A == true) && (do_trans_B == false) )
|
max@0
|
272 {
|
max@0
|
273 for(uword col_A=0; col_A < A_n_cols; ++col_A)
|
max@0
|
274 {
|
max@0
|
275 // col_A is interpreted as row_A when storing the results in matrix C
|
max@0
|
276
|
max@0
|
277 const in_eT1* A_coldata = A.colptr(col_A);
|
max@0
|
278
|
max@0
|
279 for(uword col_B=0; col_B < B_n_cols; ++col_B)
|
max@0
|
280 {
|
max@0
|
281 const in_eT2* B_coldata = B.colptr(col_B);
|
max@0
|
282
|
max@0
|
283 out_eT acc = out_eT(0);
|
max@0
|
284 for(uword i=0; i < B_n_rows; ++i)
|
max@0
|
285 {
|
max@0
|
286 acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
|
max@0
|
287 }
|
max@0
|
288
|
max@0
|
289 if( (use_alpha == false) && (use_beta == false) )
|
max@0
|
290 {
|
max@0
|
291 C.at(col_A,col_B) = acc;
|
max@0
|
292 }
|
max@0
|
293 else
|
max@0
|
294 if( (use_alpha == true) && (use_beta == false) )
|
max@0
|
295 {
|
max@0
|
296 C.at(col_A,col_B) = alpha * acc;
|
max@0
|
297 }
|
max@0
|
298 else
|
max@0
|
299 if( (use_alpha == false) && (use_beta == true) )
|
max@0
|
300 {
|
max@0
|
301 C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
|
max@0
|
302 }
|
max@0
|
303 else
|
max@0
|
304 if( (use_alpha == true) && (use_beta == true) )
|
max@0
|
305 {
|
max@0
|
306 C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
|
max@0
|
307 }
|
max@0
|
308
|
max@0
|
309 }
|
max@0
|
310 }
|
max@0
|
311 }
|
max@0
|
312 else
|
max@0
|
313 if( (do_trans_A == false) && (do_trans_B == true) )
|
max@0
|
314 {
|
max@0
|
315 for(uword row_A = 0; row_A < A_n_rows; ++row_A)
|
max@0
|
316 {
|
max@0
|
317 for(uword row_B = 0; row_B < B_n_rows; ++row_B)
|
max@0
|
318 {
|
max@0
|
319 out_eT acc = out_eT(0);
|
max@0
|
320 for(uword i = 0; i < B_n_cols; ++i)
|
max@0
|
321 {
|
max@0
|
322 acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i));
|
max@0
|
323 }
|
max@0
|
324
|
max@0
|
325 if( (use_alpha == false) && (use_beta == false) )
|
max@0
|
326 {
|
max@0
|
327 C.at(row_A,row_B) = acc;
|
max@0
|
328 }
|
max@0
|
329 else
|
max@0
|
330 if( (use_alpha == true) && (use_beta == false) )
|
max@0
|
331 {
|
max@0
|
332 C.at(row_A,row_B) = alpha * acc;
|
max@0
|
333 }
|
max@0
|
334 else
|
max@0
|
335 if( (use_alpha == false) && (use_beta == true) )
|
max@0
|
336 {
|
max@0
|
337 C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B);
|
max@0
|
338 }
|
max@0
|
339 else
|
max@0
|
340 if( (use_alpha == true) && (use_beta == true) )
|
max@0
|
341 {
|
max@0
|
342 C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B);
|
max@0
|
343 }
|
max@0
|
344 }
|
max@0
|
345 }
|
max@0
|
346 }
|
max@0
|
347 else
|
max@0
|
348 if( (do_trans_A == true) && (do_trans_B == true) )
|
max@0
|
349 {
|
max@0
|
350 for(uword row_B=0; row_B < B_n_rows; ++row_B)
|
max@0
|
351 {
|
max@0
|
352
|
max@0
|
353 for(uword col_A=0; col_A < A_n_cols; ++col_A)
|
max@0
|
354 {
|
max@0
|
355 const in_eT1* A_coldata = A.colptr(col_A);
|
max@0
|
356
|
max@0
|
357 out_eT acc = out_eT(0);
|
max@0
|
358 for(uword i=0; i < A_n_rows; ++i)
|
max@0
|
359 {
|
max@0
|
360 acc += upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
|
max@0
|
361 }
|
max@0
|
362
|
max@0
|
363 if( (use_alpha == false) && (use_beta == false) )
|
max@0
|
364 {
|
max@0
|
365 C.at(col_A,row_B) = acc;
|
max@0
|
366 }
|
max@0
|
367 else
|
max@0
|
368 if( (use_alpha == true) && (use_beta == false) )
|
max@0
|
369 {
|
max@0
|
370 C.at(col_A,row_B) = alpha * acc;
|
max@0
|
371 }
|
max@0
|
372 else
|
max@0
|
373 if( (use_alpha == false) && (use_beta == true) )
|
max@0
|
374 {
|
max@0
|
375 C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
|
max@0
|
376 }
|
max@0
|
377 else
|
max@0
|
378 if( (use_alpha == true) && (use_beta == true) )
|
max@0
|
379 {
|
max@0
|
380 C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
|
max@0
|
381 }
|
max@0
|
382
|
max@0
|
383 }
|
max@0
|
384 }
|
max@0
|
385
|
max@0
|
386 }
|
max@0
|
387 }
|
max@0
|
388
|
max@0
|
389 };
|
max@0
|
390
|
max@0
|
391
|
max@0
|
392
|
max@0
|
393
|
max@0
|
394
|
max@0
|
395 //! \brief
|
max@0
|
396 //! Matrix multplication where the matrices have differing element types.
|
max@0
|
397
|
max@0
|
398 template<const bool do_trans_A=false, const bool do_trans_B=false, const bool use_alpha=false, const bool use_beta=false>
|
max@0
|
399 class gemm_mixed
|
max@0
|
400 {
|
max@0
|
401 public:
|
max@0
|
402
|
max@0
|
403 //! immediate multiplication of matrices A and B, storing the result in C
|
max@0
|
404 template<typename out_eT, typename in_eT1, typename in_eT2>
|
max@0
|
405 inline
|
max@0
|
406 static
|
max@0
|
407 void
|
max@0
|
408 apply
|
max@0
|
409 (
|
max@0
|
410 Mat<out_eT>& C,
|
max@0
|
411 const Mat<in_eT1>& A,
|
max@0
|
412 const Mat<in_eT2>& B,
|
max@0
|
413 const out_eT alpha = out_eT(1),
|
max@0
|
414 const out_eT beta = out_eT(0)
|
max@0
|
415 )
|
max@0
|
416 {
|
max@0
|
417 arma_extra_debug_sigprint();
|
max@0
|
418
|
max@0
|
419 Mat<in_eT1> tmp_A;
|
max@0
|
420 Mat<in_eT2> tmp_B;
|
max@0
|
421
|
max@0
|
422 const bool predo_trans_A = ( (do_trans_A == true) && (is_complex<in_eT1>::value == true) );
|
max@0
|
423 const bool predo_trans_B = ( (do_trans_B == true) && (is_complex<in_eT2>::value == true) );
|
max@0
|
424
|
max@0
|
425 if(do_trans_A)
|
max@0
|
426 {
|
max@0
|
427 op_htrans::apply_noalias(tmp_A, A);
|
max@0
|
428 }
|
max@0
|
429
|
max@0
|
430 if(do_trans_B)
|
max@0
|
431 {
|
max@0
|
432 op_htrans::apply_noalias(tmp_B, B);
|
max@0
|
433 }
|
max@0
|
434
|
max@0
|
435 const Mat<in_eT1>& AA = (predo_trans_A == false) ? A : tmp_A;
|
max@0
|
436 const Mat<in_eT2>& BB = (predo_trans_B == false) ? B : tmp_B;
|
max@0
|
437
|
max@0
|
438 if( (AA.n_elem <= 64u) && (BB.n_elem <= 64u) )
|
max@0
|
439 {
|
max@0
|
440 gemm_mixed_small<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
|
max@0
|
441 }
|
max@0
|
442 else
|
max@0
|
443 {
|
max@0
|
444 gemm_mixed_large<((predo_trans_A) ? false : do_trans_A), ((predo_trans_B) ? false : do_trans_B), use_alpha, use_beta>::apply(C, AA, BB, alpha, beta);
|
max@0
|
445 }
|
max@0
|
446 }
|
max@0
|
447
|
max@0
|
448
|
max@0
|
449 };
|
max@0
|
450
|
max@0
|
451
|
max@0
|
452
|
max@0
|
453 //! @}
|