comparison armadillo-2.4.4/include/armadillo_bits/gemv.hpp @ 0:8b6102e2a9b0

Armadillo Library
author maxzanoni76 <max.zanoni@eecs.qmul.ac.uk>
date Wed, 11 Apr 2012 09:27:06 +0100
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:8b6102e2a9b0
1 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au)
2 // Copyright (C) 2008-2011 Conrad Sanderson
3 //
4 // This file is part of the Armadillo C++ library.
5 // It is provided without any warranty of fitness
6 // for any purpose. You can redistribute this file
7 // and/or modify it under the terms of the GNU
8 // Lesser General Public License (LGPL) as published
9 // by the Free Software Foundation, either version 3
10 // of the License or (at your option) any later version.
11 // (see http://www.opensource.org/licenses for more info)
12
13
14 //! \addtogroup gemv
15 //! @{
16
17
18
19 //! for tiny square matrices, size <= 4x4
20 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
21 class gemv_emul_tinysq
22 {
23 public:
24
25
26 template<const uword row, const uword col>
27 struct pos
28 {
29 static const uword n2 = (do_trans_A == false) ? (row + col*2) : (col + row*2);
30 static const uword n3 = (do_trans_A == false) ? (row + col*3) : (col + row*3);
31 static const uword n4 = (do_trans_A == false) ? (row + col*4) : (col + row*4);
32 };
33
34
35
36 template<typename eT, const uword i>
37 arma_hot
38 arma_inline
39 static
40 void
41 assign(eT* y, const eT acc, const eT alpha, const eT beta)
42 {
43 if(use_beta == false)
44 {
45 y[i] = (use_alpha == false) ? acc : alpha*acc;
46 }
47 else
48 {
49 const eT tmp = y[i];
50
51 y[i] = beta*tmp + ( (use_alpha == false) ? acc : alpha*acc );
52 }
53 }
54
55
56
57 template<typename eT>
58 arma_hot
59 inline
60 static
61 void
62 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
63 {
64 arma_extra_debug_sigprint();
65
66 const eT* Am = A.memptr();
67
68 switch(A.n_rows)
69 {
70 case 1:
71 {
72 const eT acc = Am[0] * x[0];
73
74 assign<eT, 0>(y, acc, alpha, beta);
75 }
76 break;
77
78
79 case 2:
80 {
81 const eT x0 = x[0];
82 const eT x1 = x[1];
83
84 const eT acc0 = Am[pos<0,0>::n2]*x0 + Am[pos<0,1>::n2]*x1;
85 const eT acc1 = Am[pos<1,0>::n2]*x0 + Am[pos<1,1>::n2]*x1;
86
87 assign<eT, 0>(y, acc0, alpha, beta);
88 assign<eT, 1>(y, acc1, alpha, beta);
89 }
90 break;
91
92
93 case 3:
94 {
95 const eT x0 = x[0];
96 const eT x1 = x[1];
97 const eT x2 = x[2];
98
99 const eT acc0 = Am[pos<0,0>::n3]*x0 + Am[pos<0,1>::n3]*x1 + Am[pos<0,2>::n3]*x2;
100 const eT acc1 = Am[pos<1,0>::n3]*x0 + Am[pos<1,1>::n3]*x1 + Am[pos<1,2>::n3]*x2;
101 const eT acc2 = Am[pos<2,0>::n3]*x0 + Am[pos<2,1>::n3]*x1 + Am[pos<2,2>::n3]*x2;
102
103 assign<eT, 0>(y, acc0, alpha, beta);
104 assign<eT, 1>(y, acc1, alpha, beta);
105 assign<eT, 2>(y, acc2, alpha, beta);
106 }
107 break;
108
109
110 case 4:
111 {
112 const eT x0 = x[0];
113 const eT x1 = x[1];
114 const eT x2 = x[2];
115 const eT x3 = x[3];
116
117 const eT acc0 = Am[pos<0,0>::n4]*x0 + Am[pos<0,1>::n4]*x1 + Am[pos<0,2>::n4]*x2 + Am[pos<0,3>::n4]*x3;
118 const eT acc1 = Am[pos<1,0>::n4]*x0 + Am[pos<1,1>::n4]*x1 + Am[pos<1,2>::n4]*x2 + Am[pos<1,3>::n4]*x3;
119 const eT acc2 = Am[pos<2,0>::n4]*x0 + Am[pos<2,1>::n4]*x1 + Am[pos<2,2>::n4]*x2 + Am[pos<2,3>::n4]*x3;
120 const eT acc3 = Am[pos<3,0>::n4]*x0 + Am[pos<3,1>::n4]*x1 + Am[pos<3,2>::n4]*x2 + Am[pos<3,3>::n4]*x3;
121
122 assign<eT, 0>(y, acc0, alpha, beta);
123 assign<eT, 1>(y, acc1, alpha, beta);
124 assign<eT, 2>(y, acc2, alpha, beta);
125 assign<eT, 3>(y, acc3, alpha, beta);
126 }
127 break;
128
129
130 default:
131 ;
132 }
133 }
134
135 };
136
137
138
139 //! \brief
140 //! Partial emulation of ATLAS/BLAS gemv().
141 //! 'y' is assumed to have been set to the correct size (i.e. taking into account the transpose)
142
143 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
144 class gemv_emul_large
145 {
146 public:
147
148 template<typename eT>
149 arma_hot
150 inline
151 static
152 void
153 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
154 {
155 arma_extra_debug_sigprint();
156
157 const uword A_n_rows = A.n_rows;
158 const uword A_n_cols = A.n_cols;
159
160 if(do_trans_A == false)
161 {
162 if(A_n_rows == 1)
163 {
164 const eT acc = op_dot::direct_dot_arma(A_n_cols, A.mem, x);
165
166 if( (use_alpha == false) && (use_beta == false) )
167 {
168 y[0] = acc;
169 }
170 else
171 if( (use_alpha == true) && (use_beta == false) )
172 {
173 y[0] = alpha * acc;
174 }
175 else
176 if( (use_alpha == false) && (use_beta == true) )
177 {
178 y[0] = acc + beta*y[0];
179 }
180 else
181 if( (use_alpha == true) && (use_beta == true) )
182 {
183 y[0] = alpha*acc + beta*y[0];
184 }
185 }
186 else
187 for(uword row=0; row < A_n_rows; ++row)
188 {
189 eT acc = eT(0);
190
191 for(uword i=0; i < A_n_cols; ++i)
192 {
193 acc += A.at(row,i) * x[i];
194 }
195
196 if( (use_alpha == false) && (use_beta == false) )
197 {
198 y[row] = acc;
199 }
200 else
201 if( (use_alpha == true) && (use_beta == false) )
202 {
203 y[row] = alpha * acc;
204 }
205 else
206 if( (use_alpha == false) && (use_beta == true) )
207 {
208 y[row] = acc + beta*y[row];
209 }
210 else
211 if( (use_alpha == true) && (use_beta == true) )
212 {
213 y[row] = alpha*acc + beta*y[row];
214 }
215 }
216 }
217 else
218 if(do_trans_A == true)
219 {
220 for(uword col=0; col < A_n_cols; ++col)
221 {
222 // col is interpreted as row when storing the results in 'y'
223
224
225 // const eT* A_coldata = A.colptr(col);
226 //
227 // eT acc = eT(0);
228 // for(uword row=0; row < A_n_rows; ++row)
229 // {
230 // acc += A_coldata[row] * x[row];
231 // }
232
233 const eT acc = op_dot::direct_dot_arma(A_n_rows, A.colptr(col), x);
234
235 if( (use_alpha == false) && (use_beta == false) )
236 {
237 y[col] = acc;
238 }
239 else
240 if( (use_alpha == true) && (use_beta == false) )
241 {
242 y[col] = alpha * acc;
243 }
244 else
245 if( (use_alpha == false) && (use_beta == true) )
246 {
247 y[col] = acc + beta*y[col];
248 }
249 else
250 if( (use_alpha == true) && (use_beta == true) )
251 {
252 y[col] = alpha*acc + beta*y[col];
253 }
254
255 }
256 }
257 }
258
259 };
260
261
262
263 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
264 class gemv_emul
265 {
266 public:
267
268 template<typename eT>
269 arma_hot
270 inline
271 static
272 void
273 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0), const typename arma_not_cx<eT>::result* junk = 0 )
274 {
275 arma_extra_debug_sigprint();
276 arma_ignore(junk);
277
278 const uword A_n_rows = A.n_rows;
279 const uword A_n_cols = A.n_cols;
280
281 if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) )
282 {
283 gemv_emul_tinysq<do_trans_A, use_alpha, use_beta>::apply(y, A, x, alpha, beta);
284 }
285 else
286 {
287 gemv_emul_large<do_trans_A, use_alpha, use_beta>::apply(y, A, x, alpha, beta);
288 }
289 }
290
291
292
293 template<typename eT>
294 arma_hot
295 inline
296 static
297 void
298 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0), const typename arma_cx_only<eT>::result* junk = 0 )
299 {
300 arma_extra_debug_sigprint();
301
302 Mat<eT> tmp_A;
303
304 if(do_trans_A)
305 {
306 op_htrans::apply_noalias(tmp_A, A);
307 }
308
309 const Mat<eT>& AA = (do_trans_A == false) ? A : tmp_A;
310
311 const uword AA_n_rows = AA.n_rows;
312 const uword AA_n_cols = AA.n_cols;
313
314 if( (AA_n_rows <= 4) && (AA_n_rows == AA_n_cols) )
315 {
316 gemv_emul_tinysq<false, use_alpha, use_beta>::apply(y, AA, x, alpha, beta);
317 }
318 else
319 {
320 gemv_emul_large<false, use_alpha, use_beta>::apply(y, AA, x, alpha, beta);
321 }
322 }
323 };
324
325
326
327 //! \brief
328 //! Wrapper for ATLAS/BLAS gemv function, using template arguments to control the arguments passed to gemv.
329 //! 'y' is assumed to have been set to the correct size (i.e. taking into account the transpose)
330
331 template<const bool do_trans_A=false, const bool use_alpha=false, const bool use_beta=false>
332 class gemv
333 {
334 public:
335
336 template<typename eT>
337 inline
338 static
339 void
340 apply_blas_type( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
341 {
342 arma_extra_debug_sigprint();
343
344 if(A.n_elem <= 64u)
345 {
346 gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
347 }
348 else
349 {
350 #if defined(ARMA_USE_ATLAS)
351 {
352 arma_extra_debug_print("atlas::cblas_gemv()");
353
354 atlas::cblas_gemv<eT>
355 (
356 atlas::CblasColMajor,
357 (do_trans_A) ? ( is_complex<eT>::value ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans,
358 A.n_rows,
359 A.n_cols,
360 (use_alpha) ? alpha : eT(1),
361 A.mem,
362 A.n_rows,
363 x,
364 1,
365 (use_beta) ? beta : eT(0),
366 y,
367 1
368 );
369 }
370 #elif defined(ARMA_USE_BLAS)
371 {
372 arma_extra_debug_print("blas::gemv()");
373
374 const char trans_A = (do_trans_A) ? ( is_complex<eT>::value ? 'C' : 'T' ) : 'N';
375 const blas_int m = A.n_rows;
376 const blas_int n = A.n_cols;
377 const eT local_alpha = (use_alpha) ? alpha : eT(1);
378 //const blas_int lda = A.n_rows;
379 const blas_int inc = 1;
380 const eT local_beta = (use_beta) ? beta : eT(0);
381
382 arma_extra_debug_print( arma_boost::format("blas::gemv(): trans_A = %c") % trans_A );
383
384 blas::gemv<eT>
385 (
386 &trans_A,
387 &m,
388 &n,
389 &local_alpha,
390 A.mem,
391 &m, // lda
392 x,
393 &inc,
394 &local_beta,
395 y,
396 &inc
397 );
398 }
399 #else
400 {
401 gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
402 }
403 #endif
404 }
405
406 }
407
408
409
410 template<typename eT>
411 arma_inline
412 static
413 void
414 apply( eT* y, const Mat<eT>& A, const eT* x, const eT alpha = eT(1), const eT beta = eT(0) )
415 {
416 gemv_emul<do_trans_A, use_alpha, use_beta>::apply(y,A,x,alpha,beta);
417 }
418
419
420
421 arma_inline
422 static
423 void
424 apply
425 (
426 float* y,
427 const Mat<float>& A,
428 const float* x,
429 const float alpha = float(1),
430 const float beta = float(0)
431 )
432 {
433 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
434 }
435
436
437
438 arma_inline
439 static
440 void
441 apply
442 (
443 double* y,
444 const Mat<double>& A,
445 const double* x,
446 const double alpha = double(1),
447 const double beta = double(0)
448 )
449 {
450 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
451 }
452
453
454
455 arma_inline
456 static
457 void
458 apply
459 (
460 std::complex<float>* y,
461 const Mat< std::complex<float > >& A,
462 const std::complex<float>* x,
463 const std::complex<float> alpha = std::complex<float>(1),
464 const std::complex<float> beta = std::complex<float>(0)
465 )
466 {
467 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
468 }
469
470
471
472 arma_inline
473 static
474 void
475 apply
476 (
477 std::complex<double>* y,
478 const Mat< std::complex<double> >& A,
479 const std::complex<double>* x,
480 const std::complex<double> alpha = std::complex<double>(1),
481 const std::complex<double> beta = std::complex<double>(0)
482 )
483 {
484 gemv<do_trans_A, use_alpha, use_beta>::apply_blas_type(y,A,x,alpha,beta);
485 }
486
487
488
489 };
490
491
492 //! @}