Mercurial > hg > segmenter-vamp-plugin
comparison armadillo-2.4.4/include/armadillo_bits/gemv.hpp @ 0:8b6102e2a9b0
Armadillo Library
author | maxzanoni76 <max.zanoni@eecs.qmul.ac.uk> |
---|---|
date | Wed, 11 Apr 2012 09:27:06 +0100 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:8b6102e2a9b0 |
---|---|
1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au) | |
2 // Copyright (C) 2008-2011 Conrad Sanderson | |
3 // | |
4 // This file is part of the Armadillo C++ library. | |
5 // It is provided without any warranty of fitness | |
6 // for any purpose. You can redistribute this file | |
7 // and/or modify it under the terms of the GNU | |
8 // Lesser General Public License (LGPL) as published | |
9 // by the Free Software Foundation, either version 3 | |
10 // of the License or (at your option) any later version. | |
11 // (see http://www.opensource.org/licenses for more info) | |
12 | |
13 | |
14 //! \addtogroup gemv | |
15 //! @{ | |
16 | |
17 | |
18 | |
19 //! for tiny square matrices, size <= 4x4 | |
20 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false> | |
21 class gemv_emul_tinysq | |
22 { | |
23 public: | |
24 | |
25 | |
26 template<const uword row, const uword col> | |
27 struct pos | |
28 { | |
29 static const uword n2 = (do_trans_A == false) ? (row + col*2) : (col + row*2); | |
30 static const uword n3 = (do_trans_A == false) ? (row + col*3) : (col + row*3); | |
31 static const uword n4 = (do_trans_A == false) ? (row + col*4) : (col + row*4); | |
32 }; | |
33 | |
34 | |
35 | |
36 template<typename eT, const uword i> | |
37 arma_hot | |
38 arma_inline | |
39 static | |
40 void | |
41 assign(eT* y, const eT acc, const eT alpha, const eT beta) | |
42 { | |
43 if(use_beta == false) | |
44 { | |
45 y[i] = (use_alpha == false) ? acc : alpha*acc; | |
46 } | |
47 else | |
48 { | |
49 const eT tmp = y[i]; | |
50 | |
51 y[i] = beta*tmp + ( (use_alpha == false) ? acc : alpha*acc ); | |
52 } | |
53 } | |
54 | |
55 | |
56 | |
57 template<typename eT> | |
58 arma_hot | |
59 inline | |
60 static | |
61 void | |
62 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) ) | |
63 { | |
64 arma_extra_debug_sigprint(); | |
65 | |
66 const eT* Am = A.memptr(); | |
67 | |
68 switch(A.n_rows) | |
69 { | |
70 case 1: | |
71 { | |
72 const eT acc = Am[0] * x[0]; | |
73 | |
74 assign<eT, 0>(y, acc, alpha, beta); | |
75 } | |
76 break; | |
77 | |
78 | |
79 case 2: | |
80 { | |
81 const eT x0 = x[0]; | |
82 const eT x1 = x[1]; | |
83 | |
84 const eT acc0 = Am[pos<0,0>::n2]*x0 + Am[pos<0,1>::n2]*x1; | |
85 const eT acc1 = Am[pos<1,0>::n2]*x0 + Am[pos<1,1>::n2]*x1; | |
86 | |
87 assign<eT, 0>(y, acc0, alpha, beta); | |
88 assign<eT, 1>(y, acc1, alpha, beta); | |
89 } | |
90 break; | |
91 | |
92 | |
93 case 3: | |
94 { | |
95 const eT x0 = x[0]; | |
96 const eT x1 = x[1]; | |
97 const eT x2 = x[2]; | |
98 | |
99 const eT acc0 = Am[pos<0,0>::n3]*x0 + Am[pos<0,1>::n3]*x1 + Am[pos<0,2>::n3]*x2; | |
100 const eT acc1 = Am[pos<1,0>::n3]*x0 + Am[pos<1,1>::n3]*x1 + Am[pos<1,2>::n3]*x2; | |
101 const eT acc2 = Am[pos<2,0>::n3]*x0 + Am[pos<2,1>::n3]*x1 + Am[pos<2,2>::n3]*x2; | |
102 | |
103 assign<eT, 0>(y, acc0, alpha, beta); | |
104 assign<eT, 1>(y, acc1, alpha, beta); | |
105 assign<eT, 2>(y, acc2, alpha, beta); | |
106 } | |
107 break; | |
108 | |
109 | |
110 case 4: | |
111 { | |
112 const eT x0 = x[0]; | |
113 const eT x1 = x[1]; | |
114 const eT x2 = x[2]; | |
115 const eT x3 = x[3]; | |
116 | |
117 const eT acc0 = Am[pos<0,0>::n4]*x0 + Am[pos<0,1>::n4]*x1 + Am[pos<0,2>::n4]*x2 + Am[pos<0,3>::n4]*x3; | |
118 const eT acc1 = Am[pos<1,0>::n4]*x0 + Am[pos<1,1>::n4]*x1 + Am[pos<1,2>::n4]*x2 + Am[pos<1,3>::n4]*x3; | |
119 const eT acc2 = Am[pos<2,0>::n4]*x0 + Am[pos<2,1>::n4]*x1 + Am[pos<2,2>::n4]*x2 + Am[pos<2,3>::n4]*x3; | |
120 const eT acc3 = Am[pos<3,0>::n4]*x0 + Am[pos<3,1>::n4]*x1 + Am[pos<3,2>::n4]*x2 + Am[pos<3,3>::n4]*x3; | |
121 | |
122 assign<eT, 0>(y, acc0, alpha, beta); | |
123 assign<eT, 1>(y, acc1, alpha, beta); | |
124 assign<eT, 2>(y, acc2, alpha, beta); | |
125 assign<eT, 3>(y, acc3, alpha, beta); | |
126 } | |
127 break; | |
128 | |
129 | |
130 default: | |
131 ; | |
132 } | |
133 } | |
134 | |
135 }; | |
136 | |
137 | |
138 | |
139 //! \brief | |
140 //! Partial emulation of ATLAS/BLAS gemv(). | |
141 //! 'y' is assumed to have been set to the correct size (i.e. taking into account the transpose) | |
142 | |
143 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false> | |
144 class gemv_emul_large | |
145 { | |
146 public: | |
147 | |
148 template<typename eT> | |
149 arma_hot | |
150 inline | |
151 static | |
152 void | |
153 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) ) | |
154 { | |
155 arma_extra_debug_sigprint(); | |
156 | |
157 const uword A_n_rows = A.n_rows; | |
158 const uword A_n_cols = A.n_cols; | |
159 | |
160 if(do_trans_A == false) | |
161 { | |
162 if(A_n_rows == 1) | |
163 { | |
164 const eT acc = op_dot::direct_dot_arma(A_n_cols, A.mem, x); | |
165 | |
166 if( (use_alpha == false) && (use_beta == false) ) | |
167 { | |
168 y[0] = acc; | |
169 } | |
170 else | |
171 if( (use_alpha == true) && (use_beta == false) ) | |
172 { | |
173 y[0] = alpha * acc; | |
174 } | |
175 else | |
176 if( (use_alpha == false) && (use_beta == true) ) | |
177 { | |
178 y[0] = acc + beta*y[0]; | |
179 } | |
180 else | |
181 if( (use_alpha == true) && (use_beta == true) ) | |
182 { | |
183 y[0] = alpha*acc + beta*y[0]; | |
184 } | |
185 } | |
186 else | |
187 for(uword row=0; row < A_n_rows; ++row) | |
188 { | |
189 eT acc = eT(0); | |
190 | |
191 for(uword i=0; i < A_n_cols; ++i) | |
192 { | |
193 acc += A.at(row,i) * x[i]; | |
194 } | |
195 | |
196 if( (use_alpha == false) && (use_beta == false) ) | |
197 { | |
198 y[row] = acc; | |
199 } | |
200 else | |
201 if( (use_alpha == true) && (use_beta == false) ) | |
202 { | |
203 y[row] = alpha * acc; | |
204 } | |
205 else | |
206 if( (use_alpha == false) && (use_beta == true) ) | |
207 { | |
208 y[row] = acc + beta*y[row]; | |
209 } | |
210 else | |
211 if( (use_alpha == true) && (use_beta == true) ) | |
212 { | |
213 y[row] = alpha*acc + beta*y[row]; | |
214 } | |
215 } | |
216 } | |
217 else | |
218 if(do_trans_A == true) | |
219 { | |
220 for(uword col=0; col < A_n_cols; ++col) | |
221 { | |
222 // col is interpreted as row when storing the results in 'y' | |
223 | |
224 | |
225 // const eT* A_coldata = A.colptr(col); | |
226 // | |
227 // eT acc = eT(0); | |
228 // for(uword row=0; row < A_n_rows; ++row) | |
229 // { | |
230 // acc += A_coldata[row] * x[row]; | |
231 // } | |
232 | |
233 const eT acc = op_dot::direct_dot_arma(A_n_rows, A.colptr(col), x); | |
234 | |
235 if( (use_alpha == false) && (use_beta == false) ) | |
236 { | |
237 y[col] = acc; | |
238 } | |
239 else | |
240 if( (use_alpha == true) && (use_beta == false) ) | |
241 { | |
242 y[col] = alpha * acc; | |
243 } | |
244 else | |
245 if( (use_alpha == false) && (use_beta == true) ) | |
246 { | |
247 y[col] = acc + beta*y[col]; | |
248 } | |
249 else | |
250 if( (use_alpha == true) && (use_beta == true) ) | |
251 { | |
252 y[col] = alpha*acc + beta*y[col]; | |
253 } | |
254 | |
255 } | |
256 } | |
257 } | |
258 | |
259 }; | |
260 | |
261 | |
262 | |
263 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false> | |
264 class gemv_emul | |
265 { | |
266 public: | |
267 | |
268 template<typename eT> | |
269 arma_hot | |
270 inline | |
271 static | |
272 void | |
273 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0), const typename arma_not_cx<eT>::result* junk = 0 ) | |
274 { | |
275 arma_extra_debug_sigprint(); | |
276 arma_ignore(junk); | |
277 | |
278 const uword A_n_rows = A.n_rows; | |
279 const uword A_n_cols = A.n_cols; | |
280 | |
281 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) ) | |
282 { | |
283 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(y, A, x, alpha, beta); | |
284 } | |
285 else | |
286 { | |
287 gemv_emul_large<do_trans_A, use_alpha, use_beta>::apply(y, A, x, alpha, beta); | |
288 } | |
289 } | |
290 | |
291 | |
292 | |
293 template<typename eT> | |
294 arma_hot | |
295 inline | |
296 static | |
297 void | |
298 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0), const typename arma_cx_only<eT>::result* junk = 0 ) | |
299 { | |
300 arma_extra_debug_sigprint(); | |
301 | |
302 Mat<eT> tmp_A; | |
303 | |
304 if(do_trans_A) | |
305 { | |
306 op_htrans::apply_noalias(tmp_A, A); | |
307 } | |
308 | |
309 const Mat<eT>& AA = (do_trans_A == false) ? A : tmp_A; | |
310 | |
311 const uword AA_n_rows = AA.n_rows; | |
312 const uword AA_n_cols = AA.n_cols; | |
313 | |
314 if( (AA_n_rows <= 4) && (AA_n_rows == AA_n_cols) ) | |
315 { | |
316 gemv_emul_tinysq<false, use_alpha, use_beta>::apply(y, AA, x, alpha, beta); | |
317 } | |
318 else | |
319 { | |
320 gemv_emul_large<false, use_alpha, use_beta>::apply(y, AA, x, alpha, beta); | |
321 } | |
322 } | |
323 }; | |
324 | |
325 | |
326 | |
327 //! \brief | |
328 //! Wrapper for ATLAS/BLAS gemv function, using template arguments to control the arguments passed to gemv. | |
329 //! 'y' is assumed to have been set to the correct size (i.e. taking into account the transpose) | |
330 | |
331 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false> | |
332 class gemv | |
333 { | |
334 public: | |
335 | |
336 template<typename eT> | |
337 inline | |
338 static | |
339 void | |
340 apply_blas_type( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) ) | |
341 { | |
342 arma_extra_debug_sigprint(); | |
343 | |
344 if(A.n_elem <= 64u) | |
345 { | |
346 gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta); | |
347 } | |
348 else | |
349 { | |
350 #if defined(ARMA_USE_ATLAS) | |
351 { | |
352 arma_extra_debug_print("atlas::cblas_gemv()"); | |
353 | |
354 atlas::cblas_gemv<eT> | |
355 ( | |
356 atlas::CblasColMajor, | |
357 (do_trans_A) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans, | |
358 A.n_rows, | |
359 A.n_cols, | |
360 (use_alpha) ? alpha : eT(1), | |
361 A.mem, | |
362 A.n_rows, | |
363 x, | |
364 1, | |
365 (use_beta) ? beta : eT(0), | |
366 y, | |
367 1 | |
368 ); | |
369 } | |
370 #elif defined(ARMA_USE_BLAS) | |
371 { | |
372 arma_extra_debug_print("blas::gemv()"); | |
373 | |
374 const char trans_A = (do_trans_A) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N'; | |
375 const blas_int m = A.n_rows; | |
376 const blas_int n = A.n_cols; | |
377 const eT local_alpha = (use_alpha) ? alpha : eT(1); | |
378 //const blas_int lda = A.n_rows; | |
379 const blas_int inc = 1; | |
380 const eT local_beta = (use_beta) ? beta : eT(0); | |
381 | |
382 arma_extra_debug_print( arma_boost::format("blas::gemv(): trans_A = %c") % trans_A ); | |
383 | |
384 blas::gemv<eT> | |
385 ( | |
386 &trans_A, | |
387 &m, | |
388 &n, | |
389 &local_alpha, | |
390 A.mem, | |
391 &m, // lda | |
392 x, | |
393 &inc, | |
394 &local_beta, | |
395 y, | |
396 &inc | |
397 ); | |
398 } | |
399 #else | |
400 { | |
401 gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta); | |
402 } | |
403 #endif | |
404 } | |
405 | |
406 } | |
407 | |
408 | |
409 | |
410 template<typename eT> | |
411 arma_inline | |
412 static | |
413 void | |
414 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) ) | |
415 { | |
416 gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta); | |
417 } | |
418 | |
419 | |
420 | |
421 arma_inline | |
422 static | |
423 void | |
424 apply | |
425 ( | |
426 float* y, | |
427 const Mat<float>& A, | |
428 const float* x, | |
429 const float alpha = float(1), | |
430 const float beta = float(0) | |
431 ) | |
432 { | |
433 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta); | |
434 } | |
435 | |
436 | |
437 | |
438 arma_inline | |
439 static | |
440 void | |
441 apply | |
442 ( | |
443 double* y, | |
444 const Mat<double>& A, | |
445 const double* x, | |
446 const double alpha = double(1), | |
447 const double beta = double(0) | |
448 ) | |
449 { | |
450 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta); | |
451 } | |
452 | |
453 | |
454 | |
455 arma_inline | |
456 static | |
457 void | |
458 apply | |
459 ( | |
460 std::complex<float>* y, | |
461 const Mat< std::complex<float > >& A, | |
462 const std::complex<float>* x, | |
463 const std::complex<float> alpha = std::complex<float>(1), | |
464 const std::complex<float> beta = std::complex<float>(0) | |
465 ) | |
466 { | |
467 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta); | |
468 } | |
469 | |
470 | |
471 | |
472 arma_inline | |
473 static | |
474 void | |
475 apply | |
476 ( | |
477 std::complex<double>* y, | |
478 const Mat< std::complex<double> >& A, | |
479 const std::complex<double>* x, | |
480 const std::complex<double> alpha = std::complex<double>(1), | |
481 const std::complex<double> beta = std::complex<double>(0) | |
482 ) | |
483 { | |
484 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta); | |
485 } | |
486 | |
487 | |
488 | |
489 }; | |
490 | |
491 | |
492 //! @} |